
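I. Last Week in Review and This Week's Goals

Work completed last week

| Module | Work done |
|---|---|
| Project initialization | Environment setup, dependency installation |
| Core architecture | Layered, modular design |
| Workflow engine | LangGraph StateGraph |
| Intelligent routing | Rule engine + LLM |
| Streaming API | Real-time push over SSE |
| Xiaohongshu subgraph | Content → images → publishing flow |
| Basic tests | Unit tests |

Development goals for this week

  1. Optimize the workflow engine: improve graph execution performance and strengthen state management
  2. Extend the intelligent routing system: support more routing strategies and improve intent analysis
  3. Complete the streaming API: refine the SSE implementation and add a heartbeat mechanism
  4. Extend the Xiaohongshu workflow: support multi-platform publishing and improve content generation
  5. Strengthen error handling: implement unified error handling and a retry mechanism
  6. Write integration tests: verify that all modules work together correctly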

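II. Workflow Engine Optimization

1. Graph Executor Refactoring

Refactoring goal

Improve the graph executor's performance and extensibility, and support more execution modes.

Core implementation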
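import asyncio
from typing import Awaitable, Callable, Optional

class GraphExecutor:
    """Graph executor: supports multiple execution modes

    New features:
    1. Synchronous/asynchronous execution
    2. Batch execution
    3. Execution timeout control
    4. Execution progress callbacks
    """

    def __init__(self, compiled_graph):
        self.graph = compiled_graph

    async def invoke(
        self,
        input_data: dict,
        thread_id: str,
        timeout: int = 300,
        progress_callback: Optional[Callable[[int, str], Awaitable[None]]] = None,
    ) -> dict:
        """Run the graph and return the result

        Args:
            input_data: input data
            thread_id: session (thread) ID used by the checkpointer
            timeout: timeout in seconds
            progress_callback: async callback, called as callback(percent, message)

        Returns:
            Execution result
        """
        config = {
            "configurable": {
                "thread_id": thread_id,
            }
        }

        # Timeout control (asyncio.timeout requires Python 3.11+)
        try:
            async with asyncio.timeout(timeout):
                result = await self.graph.ainvoke(input_data, config)
        except TimeoutError:
            # asyncio.timeout raises the built-in TimeoutError
            raise GraphExecutionError(f"Graph execution timed out ({timeout}s)") from None

        # Report progress outside the timeout scope, so a slow callback
        # cannot be reported as a graph timeout
        if progress_callback:
            await progress_callback(100, "Execution complete")

        return result

    async def batch_invoke(
        self,
        inputs: list[dict],
        thread_ids: list[str],
        max_concurrent: int = 5,
    ) -> list[dict]:
        """Run the graph for several inputs concurrently

        Args:
            inputs: list of input data
            thread_ids: list of session IDs, one per input
            max_concurrent: maximum number of concurrent executions

        Returns:
            List of results, in the same order as inputs
        """
        if len(inputs) != len(thread_ids):
            raise ValueError("inputs and thread_ids must have the same length")

        # Cap concurrency with a semaphore
        semaphore = asyncio.Semaphore(max_concurrent)

        async def execute_with_limit(input_data, thread_id):
            async with semaphore:
                return await self.invoke(input_data, thread_id)

        # Run in parallel; gather returns results in input order
        tasks = [
            execute_with_limit(input_data, thread_id)
            for input_data, thread_id in zip(inputs, thread_ids)
        ]

        return await asyncio.gather(*tasks)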
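The semaphore-plus-timeout pattern used by `batch_invoke` can be checked in isolation. The sketch below is illustrative rather than the project's executor: `FakeGraph` is a hypothetical stand-in that models only `ainvoke`, and it uses `asyncio.wait_for` instead of `asyncio.timeout` so that it also runs on Python versions before 3.11.

```python
import asyncio


class FakeGraph:
    """Hypothetical stand-in for a compiled graph: only ainvoke is modeled."""

    def __init__(self, delay: float):
        self.delay = delay
        self.running = 0
        self.peak = 0  # highest number of ainvoke calls in flight at once

    async def ainvoke(self, input_data: dict, config: dict) -> dict:
        self.running += 1
        self.peak = max(self.peak, self.running)
        try:
            await asyncio.sleep(self.delay)  # simulated graph work
            return {"echo": input_data["x"], "thread": config["configurable"]["thread_id"]}
        finally:
            self.running -= 1


async def run_batch(graph, inputs, thread_ids, max_concurrent=5, timeout=300.0):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def one(input_data, thread_id):
        async with semaphore:  # at most max_concurrent runs in flight
            config = {"configurable": {"thread_id": thread_id}}
            return await asyncio.wait_for(graph.ainvoke(input_data, config), timeout)

    # gather keeps input order, whatever order the runs finish in
    return await asyncio.gather(*(one(i, t) for i, t in zip(inputs, thread_ids)))


async def times_out(graph, timeout):
    """True if a single run exceeds `timeout` seconds."""
    try:
        await run_batch(graph, [{"x": 0}], ["t0"], timeout=timeout)
        return False
    except asyncio.TimeoutError:
        return True


graph = FakeGraph(delay=0.01)
results = asyncio.run(
    run_batch(graph, [{"x": n} for n in range(10)], [f"t{n}" for n in range(10)], max_concurrent=3)
)
print(graph.peak, [r["echo"] for r in results])
print(asyncio.run(times_out(FakeGraph(delay=0.5), timeout=0.01)))
```

The peak never exceeds `max_concurrent`, while results still come back in input order.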

2. State Management Enhancements

New features

Support for multiple storage backends and state recovery.

import time
from typing import Optional

# MemoryStorage, SQLiteStorage and RedisStorage are the project's storage
# backends; each exposes async save(thread_id, state) and load(thread_id).
class StateManager:
    """State manager: multiple storage backends plus an in-process cache"""

    def __init__(
        self,
        storage_backend: str = "memory",
        storage_config: Optional[dict] = None,
    ):
        """Initialize the state manager

        Args:
            storage_backend: storage backend type (memory/sqlite/redis)
            storage_config: storage configuration
        """
        self.backend = storage_backend
        self.config = storage_config or {}

        # Initialize storage according to the backend type
        if storage_backend == "memory":
            self._storage = MemoryStorage()
        elif storage_backend == "sqlite":
            db_path = self.config.get("db_path", "./state.db")
            self._storage = SQLiteStorage(db_path)
        elif storage_backend == "redis":
            redis_config = self.config.get("redis", {})
            self._storage = RedisStorage(**redis_config)
        else:
            raise ValueError(f"Unsupported storage backend: {storage_backend}")

        # State cache: {thread_id: {"state": ..., "timestamp": ...}}
        self._cache = {}
        self._cache_ttl = self.config.get("cache_ttl", 3600)  # seconds

    async def save_state(self, thread_id: str, state: dict):
        """Save state (write-through: backend first, then cache)"""
        await self._storage.save(thread_id, state)

        self._cache[thread_id] = {
            "state": state,
            "timestamp": time.time(),
        }

        # Evict expired cache entries
        await self._cleanup_cache()

    async def load_state(self, thread_id: str) -> Optional[dict]:
        """Load state (cache first)"""
        cached = self._cache.get(thread_id)
        if cached and (time.time() - cached["timestamp"]) < self._cache_ttl:
            return cached["state"]

        # Cache miss or stale entry: load from the backend
        state = await self._storage.load(thread_id)

        if state is not None:
            self._cache[thread_id] = {
                "state": state,
                "timestamp": time.time(),
            }

        return state

    async def _cleanup_cache(self):
        """Remove expired cache entries (same expiry rule as load_state)"""
        now = time.time()
        expired_keys = [
            key for key, value in self._cache.items()
            if (now - value["timestamp"]) >= self._cache_ttl
        ]

        for key in expired_keys:
            del self._cache[key]

    async def get_session_stats(self, thread_id: str) -> dict:
        """Return session statistics"""
        state = await self.load_state(thread_id)

        if not state:
            return {}

        return {
            "thread_id": thread_id,
            "message_count": len(state.get("messages", [])),
            "task_count": len([t for t in state.get("tasks", []) if t]),  # non-empty tasks only
            "created_at": state.get("created_at"),
            "last_updated": state.get("updated_at"),
        }
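The read cache in `StateManager` is essentially a TTL cache. Below is a minimal synchronous sketch of that idea; the `TTLCache` class and its injectable `clock` are assumptions for illustration, and the injected clock makes the expiry rule testable without sleeping.

```python
import time
from typing import Callable, Optional


class TTLCache:
    """Entries expire `ttl` seconds after they were written."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl
        self.clock = clock  # injectable so expiry can be tested without sleeping
        self._entries: dict[str, tuple[float, dict]] = {}

    def put(self, key: str, value: dict) -> None:
        self._entries[key] = (self.clock(), value)

    def get(self, key: str) -> Optional[dict]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        written_at, value = entry
        if self.clock() - written_at >= self.ttl:
            del self._entries[key]  # expired: evict lazily on read
            return None
        return value

    def cleanup(self) -> int:
        """Drop every expired entry; returns how many were removed."""
        now = self.clock()
        expired = [k for k, (t, _) in self._entries.items() if now - t >= self.ttl]
        for k in expired:
            del self._entries[k]
        return len(expired)


now = [0.0]
cache = TTLCache(ttl=10, clock=lambda: now[0])
cache.put("thread-1", {"messages": []})
now[0] = 5.0
print(cache.get("thread-1"))  # still fresh
now[0] = 10.0
print(cache.get("thread-1"))  # expired
```

`time.monotonic` is the default clock here because, unlike `time.time`, it is not affected by system clock adjustments. Note also that such a cache lives in a single process: with several workers sharing a Redis backend, one worker can serve a cached state that another worker has already overwritten.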

3. Error Handling

Unified error handling
import asyncio
import logging
from typing import Any, Awaitable, Callable, Optional

logger = logging.getLogger(__name__)

class GraphExecutionError(Exception):
    """Graph execution error"""

    def __init__(self, message: str, error_code: Optional[int] = None, details: Optional[dict] = None):
        super().__init__(message)
        self.error_code = error_code
        self.details = details or {}

class RetryPolicy:
    """Retry policy with exponential backoff"""

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
        retry_on_exceptions: Optional[list[type[BaseException]]] = None,
    ):
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor
        self.retry_on_exceptions = retry_on_exceptions or []

    async def execute(
        self,
        func: Callable[..., Awaitable[Any]],
        *args,
        **kwargs,
    ):
        """Call func and apply the retry policy (at most 1 + max_retries attempts)"""
        last_exception = None

        for attempt in range(self.max_retries + 1):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_exception = e

                # Non-retryable errors propagate unchanged
                if not self._should_retry(e):
                    raise

                # Do not sleep after the final attempt
                if attempt == self.max_retries:
                    break

                delay = self._calculate_delay(attempt)
                logger.warning(
                    "Operation failed, retry %d/%d in %.2f s",
                    attempt + 1, self.max_retries, delay,
                )

                await asyncio.sleep(delay)

        # All retries exhausted
        raise GraphExecutionError(
            f"Operation failed after {self.max_retries} retries",
            details={"last_error": str(last_exception)},
        ) from last_exception

    def _should_retry(self, exception: Exception) -> bool:
        """Decide whether to retry (an empty list means every exception is retryable)"""
        if not self.retry_on_exceptions:
            return True

        return isinstance(exception, tuple(self.retry_on_exceptions))

    def _calculate_delay(self, attempt: int) -> float:
        """Compute the retry delay (exponential backoff)"""
        return self.initial_delay * (self.backoff_factor ** attempt)

# Usage example
async def safe_execute_graph(executor, input_data, thread_id):
    """Execute the graph safely (with retries)"""
    policy = RetryPolicy(
        max_retries=3,
        initial_delay=1.0,
        # executor.invoke converts timeouts into GraphExecutionError
        retry_on_exceptions=[GraphExecutionError, TimeoutError],
    )

    return await policy.execute(
        executor.invoke,
        input_data,
        thread_id,
    )
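With exponential backoff, the wait before retry k (counting from 0) is `initial_delay * backoff_factor ** k`. With the defaults 1.0 and 2.0, three retries wait 1 s, 2 s and 4 s, so the worst case adds 1 + 2 + 4 = 7 s on top of the attempts themselves. The stripped-down sketch below (the `retry` function, `make_flaky` and the injectable `sleep` are illustrative assumptions) records the delays instead of sleeping, so the schedule can be checked directly.

```python
import asyncio


async def retry(func, max_retries=3, initial_delay=1.0, backoff_factor=2.0,
                retry_on=(Exception,), sleep=asyncio.sleep):
    """Call func up to 1 + max_retries times with exponential backoff."""
    for attempt in range(max_retries + 1):
        try:
            return await func()
        except retry_on:
            if attempt == max_retries:
                raise  # out of retries: surface the last error
            await sleep(initial_delay * backoff_factor ** attempt)


def make_flaky(failures: int):
    """Hypothetical operation that fails `failures` times, then succeeds."""
    calls = {"n": 0}

    async def op():
        calls["n"] += 1
        if calls["n"] <= failures:
            raise ConnectionError(f"attempt {calls['n']} failed")
        return "ok"

    return op, calls


delays = []


async def fake_sleep(seconds):
    delays.append(seconds)  # record the delay instead of waiting


op, calls = make_flaky(failures=2)
print(asyncio.run(retry(op, sleep=fake_sleep)), calls["n"], delays)  # ok 3 [1.0, 2.0]
```

Errors outside `retry_on` propagate on the first attempt, which keeps programming errors from being retried.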

III. Intelligent Routing System Extensions

1. Routing Strategy Optimization

New routing strategies
from typing import Optional

class RouterStrategy:
    """Routing strategy constants (extended)"""
    RULE_FIRST = "rule_first"        # rules first
    LLM_FIRST = "llm_first"          # LLM first
    HYBRID = "hybrid"                # hybrid strategy
    LLM_ONLY = "llm_only"            # LLM only
    RULE_ONLY = "rule_only"          # rules only
    CONFIDENCE_THRESHOLD = "confidence_threshold"  # confidence threshold strategy
    ROUND_ROBIN = "round_robin"      # round robin (load balancing)

class RouterSystem:
    """Routing system, extended with more strategies"""

    async def _route_confidence_threshold(
        self,
        user_input: str,
        context: Optional[dict],
        messages: Optional[list],
    ) -> RouteDecision:
        """Confidence threshold strategy

        Rules:
        1. Analyze with both the rule engine and the LLM
        2. If both confidences reach the threshold, merge the two decisions
        3. Otherwise, use whichever decision reaches the threshold
        4. If neither does, return a fallback decision
        """
        threshold = 0.7

        # Rule matching
        rule_decision = self.rule_engine.match(user_input, context)
        rule_confidence = rule_decision.confidence if rule_decision else 0.0

        # LLM analysis; without an analyzer the LLM confidence stays 0.0
        llm_decision = None
        llm_confidence = 0.0
        if self.intent_analyzer:
            llm_decision = await self.intent_analyzer.analyze(user_input, context, messages)
            llm_confidence = llm_decision.confidence

        # Choose by confidence
        if rule_confidence >= threshold and llm_confidence >= threshold:
            # Both qualify: merge, with the rule's target nodes taking precedence
            return self._merge_decisions(rule_decision, llm_decision)

        elif rule_confidence >= threshold:
            return rule_decision

        elif llm_confidence >= threshold:
            return llm_decision

        else:
            return self._create_fallback_decision("Confidence below threshold")

    def _merge_decisions(self, rule_decision, llm_decision) -> RouteDecision:
        """Merge two decisions"""
        return RouteDecision(
            intent=llm_decision.intent,
            # Conservative: the merged confidence is the lower of the two
            confidence=min(rule_decision.confidence, llm_decision.confidence),
            reasoning=f"Rule: {rule_decision.reasoning} | LLM: {llm_decision.reasoning}",
            response=llm_decision.response,
            # On key conflicts, LLM-extracted parameters override rule parameters
            extracted_params={**rule_decision.extracted_params, **llm_decision.extracted_params},
            target_nodes=rule_decision.target_nodes or llm_decision.target_nodes,
            should_wait=llm_decision.should_wait,
            metadata={"merged": True},
        )
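The selection logic of the confidence threshold strategy is a pure function and easy to test on its own. In the sketch below, `Decision` is a simplified stand-in for `RouteDecision` (only the fields used here are assumed), and a missing decision simply counts as confidence 0.0.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Decision:
    """Simplified stand-in for RouteDecision (assumed fields only)."""
    intent: str
    confidence: float
    target_nodes: list = field(default_factory=list)


def route_by_confidence(rule: Optional[Decision], llm: Optional[Decision],
                        threshold: float = 0.7) -> Decision:
    rule_ok = rule is not None and rule.confidence >= threshold
    llm_ok = llm is not None and llm.confidence >= threshold
    if rule_ok and llm_ok:
        # Merge: LLM intent, rule target nodes first, conservative (min) confidence
        return Decision(llm.intent, min(rule.confidence, llm.confidence),
                        rule.target_nodes or llm.target_nodes)
    if rule_ok:
        return rule
    if llm_ok:
        return llm
    return Decision("fallback", 0.0)


rule = Decision("write_note", 0.9, ["xhs_subgraph"])
llm = Decision("write_note", 0.8, ["chat"])
print(route_by_confidence(rule, llm))
print(route_by_confidence(None, Decision("chat", 0.6)).intent)  # fallback
```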

2. Intent Analyzer Enhancements

Multi-model support and context awareness
import json
import logging
from typing import Optional

from langchain_core.messages import HumanMessage, SystemMessage

logger = logging.getLogger(__name__)

class IntentAnalyzer:
    """Intent analyzer (enhanced)"""

    def __init__(
        self,
        llm_providers: Optional[list[str]] = None,
        fallback_provider: str = "qwen-plus",
        temperature: float = 0.3,
    ):
        """Initialize the intent analyzer

        Args:
            llm_providers: list of LLM providers (rotated for load balancing)
            fallback_provider: fallback provider
            temperature: sampling temperature
        """
        self.providers = llm_providers or ["qwen-plus"]
        self.fallback_provider = fallback_provider
        self.temperature = temperature

        # Round-robin index
        self._provider_index = 0

        # Model cache: {provider: chat model}
        self._models = {}

    async def analyze(
        self,
        user_input: str,
        context: Optional[dict] = None,
        messages: Optional[list] = None,
        use_fallback: bool = False,
    ) -> RouteDecision:
        """Analyze the user's intent (with one fallback attempt)"""
        provider = self._select_provider(use_fallback)

        try:
            return await self._analyze_with_provider(provider, user_input, context, messages)
        except Exception as e:
            logger.error(f"Provider {provider} failed: {e}")

            # Already on the fallback provider: re-raise
            if use_fallback:
                raise

            # Try the fallback provider
            logger.info(f"Falling back to {self.fallback_provider}")
            return await self.analyze(user_input, context, messages, use_fallback=True)

    def _select_provider(self, use_fallback: bool = False) -> str:
        """Select an LLM provider (round robin)"""
        if use_fallback:
            return self.fallback_provider

        provider = self.providers[self._provider_index]
        self._provider_index = (self._provider_index + 1) % len(self.providers)
        return provider

    async def _analyze_with_provider(
        self,
        provider: str,
        user_input: str,
        context: Optional[dict],
        messages: Optional[list],
    ) -> RouteDecision:
        """Analyze with a specific provider"""
        # Get or lazily create the model (_create_model is a project helper)
        if provider not in self._models:
            self._models[provider] = self._create_model(provider)

        llm = self._models[provider]
        structured_llm = llm.with_structured_output(IntentAnalysisOutput)

        # Put the context into the leading system message: some chat model
        # integrations reject system messages placed after other messages
        system_prompt = self._build_prompt()
        if context:
            context_str = json.dumps(context, ensure_ascii=False)
            system_prompt += f"\n\nContext: {context_str}"

        msg_list = [SystemMessage(content=system_prompt)]

        # Keep only the last 5 history messages
        if messages:
            msg_list.extend(messages[-5:])

        msg_list.append(HumanMessage(content=user_input))

        output = await structured_llm.ainvoke(msg_list)
        return self._convert_to_decision(output, provider)
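Round-robin selection with a single fallback attempt can be sketched without any LLM. Here `RoundRobinAnalyzer`, `fake_call` and the provider names `model-a`/`model-b` are hypothetical, and `itertools.cycle` replaces the manual index arithmetic.

```python
import asyncio
import itertools


class RoundRobinAnalyzer:
    """Rotate through providers; on failure, retry once with the fallback."""

    def __init__(self, providers, fallback, call):
        self._cycle = itertools.cycle(providers)
        self.fallback = fallback
        self._call = call  # async (provider, text) -> result; hypothetical transport

    async def analyze(self, text):
        provider = next(self._cycle)
        try:
            return await self._call(provider, text)
        except Exception:
            # Single fallback attempt; an error from the fallback propagates
            return await self._call(self.fallback, text)


async def fake_call(provider, text):
    if provider == "model-b":
        raise RuntimeError("model-b unavailable")
    return f"{provider}:{text}"


async def run(analyzer, n):
    return [await analyzer.analyze(f"q{i}") for i in range(n)]


analyzer = RoundRobinAnalyzer(["model-a", "model-b"], "qwen-plus", fake_call)
print(asyncio.run(run(analyzer, 3)))  # ['model-a:q0', 'qwen-plus:q1', 'model-a:q2']
```

Because `next(self._cycle)` runs without an `await` in between, coroutines sharing one event loop cannot observe a half-updated rotation.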

3. Rule Engine Extensions

Support for complex and dynamic rules
from datetime import datetime
from typing import Callable, Optional

class RuleEngine:
    """Rule engine supporting complex rules"""

    def __init__(self):
        self.rules = []
        self.functions = {}  # custom functions

    def add_rule(self, rule: RuleConfig):
        """Add a static rule"""
        self.rules.append(rule)

    def add_dynamic_rule(self, condition: Callable, action: Callable):
        """Add a dynamic rule.

        Example:

        def condition(user_input, context):
            # "紧急" means "urgent"
            return "紧急" in user_input and (context or {}).get("priority") == "high"

        def action(user_input, context):
            return RouteDecision(
                intent="urgent_request",
                target_nodes=["priority_handler"],
                confidence=1.0,
            )

        rule_engine.add_dynamic_rule(condition, action)
        """
        dynamic_rule = {
            "condition": condition,
            "action": action,
        }
        self.rules.append(dynamic_rule)

    def register_function(self, name: str, func: Callable):
        """Register a custom function"""
        self.functions[name] = func

    def match(self, user_input: str, context: Optional[dict]) -> Optional[RouteDecision]:
        """Match rules in insertion order (dynamic rules included); the first match wins"""
        for rule in self.rules:
            if isinstance(rule, dict):
                # Dynamic rule
                if rule["condition"](user_input, context):
                    return rule["action"](user_input, context)
            else:
                # Static rule
                if rule.match(user_input, context):
                    return self._create_decision(rule)

        return None

# Custom rule example
class TimeBasedRule(RuleConfig):
    """Time-based rule"""

    def __init__(self, start_hour: int, end_hour: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_hour = start_hour
        self.end_hour = end_hour

    def match(self, user_input: str, context: Optional[dict]) -> bool:
        """Match only within [start_hour, end_hour) of the server's local time"""
        current_hour = datetime.now().hour
        in_time_range = self.start_hour <= current_hour < self.end_hour

        # Both the time condition and the keyword condition must hold
        return in_time_range and super().match(user_input, context)

# Usage example
rule_engine = RuleEngine()

# Time rule: during working hours (9:00-18:00) prioritize work-related requests
work_hours_rule = TimeBasedRule(
    start_hour=9,
    end_hour=18,
    route_id="work_hours",
    keywords=["工作", "任务", "报告"],  # "work", "task", "report"
    target_nodes=["work_handler"],
)
rule_engine.add_rule(work_hours_rule)
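The first-match order of `RuleEngine.match` and the time window of `TimeBasedRule` can be exercised with an injected clock. Everything below is a hypothetical stand-in (`KeywordRule` in place of `RuleConfig`, English keywords, the `hour_fn` parameter), written only to illustrate the matching semantics.

```python
from datetime import datetime
from typing import Callable, Optional


class KeywordRule:
    """Hypothetical static rule: matches when any keyword appears in the input."""

    def __init__(self, route_id: str, keywords: list, target_nodes: list):
        self.route_id = route_id
        self.keywords = keywords
        self.target_nodes = target_nodes

    def match(self, user_input: str, context: Optional[dict]) -> bool:
        return any(k in user_input for k in self.keywords)


class TimeWindowRule(KeywordRule):
    """Keyword rule that only fires within [start_hour, end_hour)."""

    def __init__(self, start_hour: int, end_hour: int, *args,
                 hour_fn: Callable[[], int] = lambda: datetime.now().hour, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_hour, self.end_hour = start_hour, end_hour
        self.hour_fn = hour_fn  # injectable clock, so the rule is testable

    def match(self, user_input, context):
        in_window = self.start_hour <= self.hour_fn() < self.end_hour
        return in_window and super().match(user_input, context)


def first_match(rules, user_input, context=None):
    """Rules are tried in insertion order; the first match wins."""
    for rule in rules:
        if isinstance(rule, dict):  # dynamic rule: {"condition": f, "action": g}
            if rule["condition"](user_input, context):
                return rule["action"](user_input, context)
        elif rule.match(user_input, context):
            return rule.route_id
    return None


hour = [10]
rules = [
    {"condition": lambda text, ctx: "urgent" in text and (ctx or {}).get("priority") == "high",
     "action": lambda text, ctx: "priority_handler"},
    TimeWindowRule(9, 18, "work_hours", ["report", "task"], ["work_handler"],
                   hour_fn=lambda: hour[0]),
]
print(first_match(rules, "urgent: finish the report", {"priority": "high"}))  # priority_handler
print(first_match(rules, "finish the report"))  # work_hours
hour[0] = 20
print(first_match(rules, "finish the report"))  # None
```

`datetime.now()` uses the server's local time zone, so a deployment in a different zone shifts the window unless the hour is computed in an explicit zone.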

IV. Streaming API Improvements

1. SSE Protocol Optimization

Enhanced SSE implementation
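import asyncio
import json
import uuid
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse

HEARTBEAT_INTERVAL = 15  # seconds of silence before a heartbeat is sent

def create_streaming_router(executor: "StreamingGraphExecutor") -> APIRouter:
    """Create the streaming API router (enhanced)"""
    router = APIRouter(prefix="/api/v1", tags=["streaming"])

    @router.get("/chat/stream")
    async def chat_stream(
        message: str = Query(..., description="User message"),
        thread_id: Optional[str] = Query(None, description="Session ID"),
        user_id: Optional[str] = Query(None, description="User ID"),
        stream_mode: str = Query("full", description="Stream mode: full/light"),
    ):
        """Streaming chat endpoint (enhanced)

        Stream modes:
        - full: includes every node event
        - light: includes key events only
        """
        if not thread_id:
            thread_id = str(uuid.uuid4())

        return StreamingResponse(
            stream_graph_sse(
                executor=executor,
                user_input=message,
                thread_id=thread_id,
                user_id=user_id,
                stream_mode=stream_mode,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # turn off Nginx proxy buffering
                # "*" is only accepted by browsers for requests without credentials
                "Access-Control-Allow-Origin": "*",
            },
        )

    @router.get("/chat/stream/status")
    async def stream_status(thread_id: str = Query(..., description="Session ID")):
        """Query the status of a streaming session"""
        status = await executor.get_session_status(thread_id)

        return {
            "thread_id": thread_id,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

    return router

def _sse(event: str, data: dict) -> str:
    """Format one SSE frame; default=str keeps non-JSON values (e.g. message objects) from breaking the stream"""
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"

async def stream_graph_sse(
    executor: "StreamingGraphExecutor",
    user_input: str,
    thread_id: str,
    user_id: Optional[str] = None,
    stream_mode: str = "full",
):
    """Generate the SSE event stream (enhanced).

    This function is itself an async generator, so StreamingResponse can
    consume it directly. The graph runs in a background task that feeds a
    queue; whenever no event arrives for HEARTBEAT_INTERVAL seconds, a
    heartbeat comment is sent instead.
    """
    # Start event
    yield _sse("started", {"thread_id": thread_id, "mode": stream_mode})

    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel: the graph stream has ended

    async def pump():
        try:
            async for event in executor.stream_async(
                user_input=user_input,
                thread_id=thread_id,
                user_id=user_id,
            ):
                queue.put_nowait(event)
        except Exception as e:
            queue.put_nowait({"type": "error", "data": {"error": str(e)}})
        finally:
            queue.put_nowait(done)

    producer = asyncio.create_task(pump())
    try:
        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL)
            except asyncio.TimeoutError:
                # SSE comment line: keeps the connection alive, ignored by EventSource
                yield ": heartbeat\n\n"
                continue
            if event is done:
                break

            event_type = event.get("type")
            event_data = event.get("data", {})

            # Light mode: skip everything except key events
            if stream_mode == "light" and event_type not in ("node_output", "completed", "error"):
                continue

            if event_type == "node_output":
                # Strip fields the client does not need
                event_data = _compress_event_data(event_data)

            yield _sse(event_type, event_data)
    finally:
        # Also runs when the client disconnects; never yield inside finally
        producer.cancel()

    # End event, sent only when the stream finishes normally
    yield _sse("completed", {"thread_id": thread_id})

def _compress_event_data(data: dict) -> dict:
    """Compress event data by dropping unneeded fields"""
    # "state" must stay: the frontend reads data.state from node_output events
    fields_to_keep = {"node", "state", "message", "progress", "timestamp"}

    return {k: v for k, v in data.items() if k in fields_to_keep}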
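An SSE frame is a group of `field: value` lines terminated by a blank line, and a line starting with `:` is a comment that clients discard, which is what makes `: heartbeat` usable as a keepalive. The sketch below merges a slow event source with heartbeats through a queue and a timed `queue.get()`. The names `with_heartbeat` and `slow_source` are illustrative, and the 0.01 s interval is shortened only for the demo.

```python
import asyncio
import json


def sse_event(event: str, data: dict) -> str:
    """One SSE frame: an 'event:' line, a 'data:' line, then a blank line."""
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


async def with_heartbeat(source, interval: float):
    """Re-yield frames from an async iterator, emitting an SSE comment
    (': heartbeat') whenever `interval` seconds pass without a frame."""
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of the source

    async def pump():
        try:
            async for item in source:
                queue.put_nowait(item)
        finally:
            queue.put_nowait(done)

    task = asyncio.create_task(pump())
    try:
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=interval)
            except asyncio.TimeoutError:
                yield ": heartbeat\n\n"
                continue
            if item is done:
                break
            yield item
    finally:
        task.cancel()  # stop the producer if the consumer goes away early


async def slow_source():
    yield sse_event("started", {"thread_id": "t1"})
    await asyncio.sleep(0.05)  # a slow node, long enough for heartbeats
    yield sse_event("completed", {"thread_id": "t1"})


async def collect():
    return [frame async for frame in with_heartbeat(slow_source(), interval=0.01)]


frames = asyncio.run(collect())
print(frames[0], frames[-1], sum(f == ": heartbeat\n\n" for f in frames))
```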

2. Streaming Executor Enhancements

Multi-mode execution
from datetime import datetime
from typing import Optional

class StreamingGraphExecutor:
    """Streaming graph executor (enhanced)"""

    def __init__(self, compiled_graph):
        self.graph = compiled_graph
        self._session_tasks = {}  # session -> task mapping

    async def stream_async(
        self,
        user_input: str,
        thread_id: str,
        user_id: Optional[str] = None,
        mode: str = "streaming",
    ):
        """Stream the graph execution asynchronously

        Execution modes:
        - streaming: emit every node update as soon as it arrives
        - progressive: emit a node only once it has completed
        - batch: emit a single event after the whole run completes
        """
        config = {"configurable": {"thread_id": thread_id}}

        # Load the current checkpointed state (requires a checkpointer)
        current_state = await self.graph.aget_state(config)

        # Build the input (_build_input is a project helper, not shown)
        input_data = self._build_input(user_input, current_state, user_id)

        if mode == "streaming":
            async for event in self._streaming_execute(input_data, config):
                yield event

        elif mode == "progressive":
            async for event in self._progressive_execute(input_data, config):
                yield event

        elif mode == "batch":
            result = await self.graph.ainvoke(input_data, config)
            yield {"type": "completed", "data": result}

        else:
            raise ValueError(f"Unknown mode: {mode}")

    async def _streaming_execute(self, input_data: dict, config: dict):
        """Streaming execution: one event per node update"""
        async for step in self.graph.astream(input_data, config):
            for node, state in step.items():
                yield {
                    "type": "node_output",
                    "data": {
                        "node": node,
                        "state": state,
                        "timestamp": datetime.now().isoformat(),
                    },
                }

    async def _progressive_execute(self, input_data: dict, config: dict):
        """Progressive execution: emit a node only after it has completed"""
        async for step in self.graph.astream(input_data, config):
            for node, state in step.items():
                if self._is_node_complete(state):
                    yield {
                        "type": "node_complete",
                        "data": {
                            "node": node,
                            "state": state,
                            "timestamp": datetime.now().isoformat(),
                        },
                    }

    def _is_node_complete(self, state) -> bool:
        """A node counts as complete when its update carries an output or error key"""
        # A node update may be None (node returned nothing), so check the type first
        return isinstance(state, dict) and ("output" in state or "error" in state)
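The difference between the `streaming` and `progressive` modes comes down to a filter over per-node updates. This sketch uses a hypothetical `FakeGraph` whose `astream` yields `{node: update}` dicts, the shape of LangGraph's `"updates"` stream mode that the loops above iterate over; a node that returns nothing is modeled here as a `None` update.

```python
import asyncio


class FakeGraph:
    """Hypothetical stand-in whose astream yields {node: update} dicts."""

    async def astream(self, input_data, config):
        yield {"router": {"intent": "write_note"}}
        yield {"generate": None}                 # node that returned no update
        yield {"publish": {"output": "posted"}}


def is_node_complete(update) -> bool:
    # Updates can be None, so check the type before looking for keys
    return isinstance(update, dict) and ("output" in update or "error" in update)


async def collect(graph, progressive: bool):
    nodes = []
    async for step in graph.astream({}, {}):
        for node, update in step.items():
            if progressive and not is_node_complete(update):
                continue  # progressive mode: only completed nodes
            nodes.append(node)
    return nodes


print(asyncio.run(collect(FakeGraph(), progressive=False)))  # ['router', 'generate', 'publish']
print(asyncio.run(collect(FakeGraph(), progressive=True)))   # ['publish']
```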

3. Frontend Integration Example

Complete frontend integration code
// frontend/src/services/streamingService.js
class StreamingService {
    constructor(baseUrl = '/api/v1') {
        this.baseUrl = baseUrl;
        this.eventSource = null;
        this.listeners = {};
    }

    connect(message, threadId = null, options = {}) {
        // Close any existing connection first
        this.disconnect();

        // Build the URL
        const params = new URLSearchParams({
            message: message,
            stream_mode: options.mode || 'full',
        });

        if (threadId) {
            params.set('thread_id', threadId);
        }

        const url = `${this.baseUrl}/chat/stream?${params.toString()}`;

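        // Create the EventSource. For cross-origin requests with cookies, the
        // server must answer with an explicit origin, not Access-Control-Allow-Origin: *
        this.eventSource = new EventSource(url, {
            withCredentials: true,
        });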

        // Register event listeners
        this.eventSource.addEventListener('started', (event) => {
            this._emit('started', JSON.parse(event.data));
        });

        this.eventSource.addEventListener('node_start', (event) => {
            this._emit('nodeStart', JSON.parse(event.data));
        });

        this.eventSource.addEventListener('node_output', (event) => {
            this._emit('nodeOutput', JSON.parse(event.data));
        });

        this.eventSource.addEventListener('subgraph_start', (event) => {
            this._emit('subgraphStart', JSON.parse(event.data));
        });

        this.eventSource.addEventListener('subgraph_node_output', (event) => {
            this._emit('subgraphNodeOutput', JSON.parse(event.data));
        });

        this.eventSource.addEventListener('completed', (event) => {
            this._emit('completed', JSON.parse(event.data));
            this.disconnect();
        });

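        // 'error' fires both for server-sent "event: error" frames (with data)
        // and for native connection errors (without data)
        this.eventSource.addEventListener('error', (event) => {
            const data = event.data ? JSON.parse(event.data) : { error: 'connection error' };
            this._emit('error', data);
            this.disconnect();
        });

        // No heartbeat listener: the server sends heartbeats as SSE comment
        // lines (": heartbeat"), which keep the connection alive but are never
        // dispatched as events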
    }

    disconnect() {
        if (this.eventSource) {
            this.eventSource.close();
            this.eventSource = null;
        }
    }

    on(eventName, callback) {
        if (!this.listeners[eventName]) {
            this.listeners[eventName] = [];
        }
        this.listeners[eventName].push(callback);
    }

    off(eventName, callback) {
        if (this.listeners[eventName]) {
            this.listeners[eventName] = this.listeners[eventName].filter(
                (cb) => cb !== callback
            );
        }
    }

    _emit(eventName, data) {
        if (this.listeners[eventName]) {
            this.listeners[eventName].forEach((callback) => callback(data));
        }
    }
}

// Usage example
const streamingService = new StreamingService();

streamingService.on('started', (data) => {
    console.log('Session started:', data.thread_id);
});

streamingService.on('nodeOutput', (data) => {
    console.log('Node output:', data.node, data.state);
    // Update the progress display (updateProgress is an app-defined helper)
    updateProgress(data.node, data.state);
});

streamingService.on('completed', (data) => {
    console.log('Session completed:', data.thread_id);
    showCompletionMessage();
});

streamingService.on('error', (error) => {
    console.error('Streaming connection error:', error);
    showErrorMessage('Connection lost, please try again');
});

// Send a request
streamingService.connect('Write a Xiaohongshu note for me', userThreadId);
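On the client side, comment lines such as `: heartbeat` are consumed by the browser's SSE parser and never reach `addEventListener`. The Python sketch below follows the same basic parsing rules (a blank line dispatches, `:` starts a comment, the event type defaults to `message`). It ignores `id`, `retry` and CRLF line endings, so treat it as an illustration rather than a complete parser; `parse_sse` is a hypothetical name.

```python
def parse_sse(stream: str) -> list[tuple[str, str]]:
    """Minimal SSE parser: returns (event_type, data) pairs in order."""
    events = []
    event_type, data_lines = "message", []
    for line in stream.split("\n"):
        if line == "":
            if data_lines:  # dispatch only if at least one data line was seen
                events.append((event_type, "\n".join(data_lines)))
            event_type, data_lines = "message", []
        elif line.startswith(":"):
            continue  # comment line, e.g. ": heartbeat"
        else:
            field, _, value = line.partition(":")
            value = value[1:] if value.startswith(" ") else value  # drop one leading space
            if field == "event":
                event_type = value
            elif field == "data":
                data_lines.append(value)
    return events


raw = (
    'event: started\ndata: {"thread_id": "t1"}\n\n'
    ": heartbeat\n\n"
    'event: node_output\ndata: {"node": "router"}\n\n'
)
print(parse_sse(raw))
# [('started', '{"thread_id": "t1"}'), ('node_output', '{"node": "router"}')]
```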

V. Xiaohongshu Workflow Extensions

1. Content Generation Optimization

Smart content generation strategies
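class XHSContentGenerator:
    """Xiaohongshu content generator (optimized)"""

    def __init__(self):
        self._llm_client = None  # used by _call_llm (project helper, not shown)

    async def generate(
        self,
        topic: str,
        content_type: str = "normal",
        style: str = "professional",
        length: str = "medium",
    ) -> dict:
        """Generate Xiaohongshu content

        Args:
            topic: topic
            content_type: content type (normal/lifestyle/emotional)
            style: style (professional/casual/humorous/emotional)
            length: length (short/medium/long)

        Returns:
            Dict containing the title, body and tags
        """
        # Pick a generation strategy; every type other than "lifestyle"
        # (including "emotional") uses the normal template
        if content_type == "lifestyle":
            return await self._generate_lifestyle_content(topic, style)
        else:
            return await self._generate_normal_content(topic, style, length)

    async def _generate_normal_content(
        self,
        topic: str,
        style: str,
        length: str,
    ) -> dict:
        """Generate normal content"""
        length_map = {
            "short": 300,
            "medium": 600,
            "long": 1000,
        }
        # Unknown length values fall back to "medium" instead of raising KeyError
        target_length = length_map.get(length, length_map["medium"])

        prompt = f"""Please write a Xiaohongshu note about "{topic}".

Requirements:
- Write the note in Chinese
- Style: {self._style_to_description(style)}
- Length: about {target_length} characters
- Include an eye-catching title
- Use 3-5 fitting hashtags
- Match the style of the Xiaohongshu platform

Output format:
Title: [title]
Body: [body]
Tags: #[tag1] #[tag2] #[tag3]"""

        result = await self._call_llm(prompt)
        return self._parse_content(result)

    async def _generate_lifestyle_content(self, topic: str, style: str) -> dict:
        """Generate lifestyle content"""
        prompt = f"""Please write a lifestyle Xiaohongshu note about "{topic}" from a first-person perspective.

Requirements:
- Write the note in Chinese
- Style: {self._style_to_description(style)}
- Warm, natural tone, like chatting with a friend
- Personal feelings and emotions are welcome
- Avoid overly formal language
- Include emoji for a friendly feel
- Use 3-5 fitting hashtags

Output format:
Title: [title]
Body: [body]
Tags: #[tag1] #[tag2] #[tag3]"""

        result = await self._call_llm(prompt)
        return self._parse_content(result)

    def _style_to_description(self, style: str) -> str:
        """Map a style key to its prompt description"""
        styles = {
            "professional": "professional and formal, suited to knowledge sharing",
            "casual": "relaxed and informal, suited to everyday sharing",
            "humorous": "humorous and fun, with a bit of playful teasing",
            "emotional": "heartfelt and warm, emotionally touching",
        }
        return styles.get(style, "natural and balanced")

    def _parse_content(self, llm_result: str) -> dict:
        """Parse the LLM output into title, body and tags"""
        result = {
            "title": "",
            "content": "",
            "tags": [],
        }

        current_section = None

        for line in llm_result.strip().split("\n"):
            line = line.strip()
            if line.startswith("Title:"):
                result["title"] = line[len("Title:"):].strip()
                current_section = None
            elif line.startswith("Body:"):
                result["content"] = line[len("Body:"):].strip()
                current_section = "content"
            elif line.startswith("Tags:"):
                tags_str = line[len("Tags:"):].strip()
                # Splitting on "#" handles both "#a #b" and "#a#b"
                result["tags"] = tags_str.replace("#", " ").split()
                current_section = None
            elif current_section == "content":
                # Continuation line of a multi-line body
                result["content"] += "\n" + line

        result["content"] = result["content"].strip()
        return result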
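A line-prefix parser depends on the model reproducing the labels exactly, and Chinese-language output commonly uses the full-width colon `：` instead of `:`. A slightly more tolerant sketch follows; the `parse_note` function and its `labels` parameter are assumptions, and it accepts either colon with English or Chinese labels.

```python
import re

ENGLISH_LABELS = {"title": "Title", "content": "Body", "tags": "Tags"}


def parse_note(text: str, labels: dict = ENGLISH_LABELS) -> dict:
    """Split 'Label: value' sections; lines after the body label extend the body.
    Accepts both the ASCII ':' and the full-width '：' after a label."""
    pattern = re.compile(
        r"^\s*(%s)\s*[:：]\s*(.*)$" % "|".join(re.escape(v) for v in labels.values())
    )
    key_for = {v: k for k, v in labels.items()}
    result = {"title": "", "content": "", "tags": []}
    current = None
    for line in text.strip().splitlines():
        m = pattern.match(line)
        if m:
            current = key_for[m.group(1)]
            value = m.group(2).strip()
            if current == "tags":
                result["tags"] = value.replace("#", " ").split()
            else:
                result[current] = value
        elif current == "content":
            result["content"] += "\n" + line.strip()
    result["content"] = result["content"].strip()
    return result


sample = "Title：Weekend hiking\nBody: Day one.\nDay two.\nTags: #hiking #outdoors"
print(parse_note(sample))
print(parse_note("标题：周末徒步\n正文：第一天\n标签：#徒步 #户外",
                 {"title": "标题", "content": "正文", "tags": "标签"}))
```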

2. Publishing Flow Improvements

Enhanced publishing flow
from typing import Optional

class XHSPublishFlow:
    """Xiaohongshu publishing flow (improved)"""

    def __init__(self):
        self._platform_manager = PlatformManager()
        self._content_generator = XHSContentGenerator()

    async def execute(
        self,
        topic: str,
        platforms: Optional[list] = None,
        should_publish: bool = True,
        content_options: Optional[dict] = None,
    ) -> dict:
        """Run the complete publishing flow

        Args:
            topic: content topic
            platforms: target platforms (defaults to ["xiaohongshu"])
            should_publish: whether to publish
            content_options: content generation options

        Returns:
            Execution result
        """
        result = {
            "success": False,
            "steps": [],
            "output": {},
        }
        content_options = content_options or {}

        try:
            # Step 1: generate content
            content_result = await self._step_generate_content(topic, content_options)
            result["output"]["content"] = content_result

            if not content_result.get("success"):
                result["steps"].append({"step": "content_generation", "status": "failed"})
                result["error"] = "Content generation failed"
                return result
            result["steps"].append({"step": "content_generation", "status": "completed"})

            # Step 2: generate images (if requested). Kept in separate variables
            # so the content result is still available for publishing.
            images = []
            if content_options.get("need_images", False):
                image_result = await self._step_generate_images(topic)
                result["steps"].append({"step": "image_generation", "status": "completed"})
                result["output"]["images"] = image_result
                images = image_result.get("images", [])

            # Step 3: publish (if requested)
            if should_publish:
                publish_result = await self._step_publish(
                    content_result["content"],
                    images,
                    platforms or ["xiaohongshu"],
                )
                status = "completed" if publish_result["success"] else "failed"
                result["steps"].append({"step": "publishing", "status": status})
                result["output"]["publish"] = publish_result

                if not publish_result["success"]:
                    result["error"] = "Publishing failed on at least one platform"
                    return result

            result["success"] = True

        except Exception as e:
            result["success"] = False
            result["error"] = str(e)
            result["steps"].append({"step": "error", "status": "failed", "error": str(e)})

        return result

    async def _step_generate_content(self, topic: str, options: Optional[dict]) -> dict:
        """Content generation step"""
        options = options or {}

        result = await self._content_generator.generate(
            topic=topic,
            content_type=options.get("content_type", "normal"),
            style=options.get("style", "professional"),
            length=options.get("length", "medium"),
        )

        return {
            "success": True,
            "content": result,
        }

    async def _step_generate_images(self, topic: str) -> dict:
        """Image generation step"""
        # Call the image generation service
        # ...

        return {
            "success": True,
            "images": [],  # list of image URLs
        }

    async def _step_publish(
        self,
        content: dict,
        images: list,
        platforms: list,
    ) -> dict:
        """Publishing step"""
        results = await self._platform_manager.publish_to_platforms(
            content=content,
            platforms=platforms,
            images=images,
        )

        # Succeeds only if every platform succeeded
        all_success = all(r["success"] for r in results.values())

        return {
            "success": all_success,
            "results": results,
        }
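`PlatformManager.publish_to_platforms` is not shown. One plausible shape for it, assumed here rather than taken from the project, publishes concurrently with `asyncio.gather(..., return_exceptions=True)` so that one failing platform does not cancel the others. The platform clients `xhs` and `weibo` are hypothetical.

```python
import asyncio


async def publish_all(publishers: dict, content: dict) -> dict:
    """Publish to every platform concurrently.
    Returns {platform: {"success": bool, ...}} for every platform."""
    names = list(publishers)
    outcomes = await asyncio.gather(
        *(publishers[n](content) for n in names), return_exceptions=True
    )
    results = {}
    for name, outcome in zip(names, outcomes):
        if isinstance(outcome, BaseException):  # also catches CancelledError
            results[name] = {"success": False, "error": str(outcome)}
        else:
            results[name] = {"success": True, "post_id": outcome}
    return results


async def xhs(content):  # hypothetical platform client
    return "xhs-123"


async def weibo(content):  # hypothetical platform client that fails
    raise RuntimeError("token expired")


results = asyncio.run(publish_all({"xiaohongshu": xhs, "weibo": weibo}, {"title": "hi"}))
all_success = all(r["success"] for r in results.values())
print(results, all_success)
```

The per-platform dictionary matches the `r["success"]` check in `_step_publish`, and a partial failure is visible per platform instead of aborting the whole batch.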

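VI. Problems Encountered and Solutions

Problem 1: Graph execution performance

Symptom: graph execution takes too long, especially in complex workflows.

Root cause analysis

  1. The graph was recompiled on every execution
  2. State checkpoint operations were too frequent
  3. There was no caching mechanism

Solution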

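import asyncio

class CachedGraphExecutor:
    """Graph executor with a cached compiled graph"""

    def __init__(self, graph_builder: GraphBuilder):
        self.graph_builder = graph_builder
        self._compiled_graph = None
        self._compile_lock = asyncio.Lock()

    async def get_graph(self):
        """Return the compiled graph, compiling it at most once"""
        # Fast path: no locking once the graph has been compiled
        if self._compiled_graph is not None:
            return self._compiled_graph

        async with self._compile_lock:
            # Re-check under the lock: another task may have compiled it meanwhile
            if self._compiled_graph is None:
                # GraphBuilder.compile() is assumed to be async here; LangGraph's
                # own StateGraph.compile() is synchronous
                self._compiled_graph = await self.graph_builder.compile()
            return self._compiled_graph

    async def invalidate_cache(self):
        """Invalidate the cache (call when the configuration changes)"""
        async with self._compile_lock:
            self._compiled_graph = None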
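The compile-once cache is an instance of double-checked locking: a lock-free fast path, then a re-check under the lock. The sketch below uses a generic `CompileOnce` wrapper and a `fake_compile` coroutine, both hypothetical, to show that 20 concurrent callers trigger a single compile and that invalidation forces exactly one more.

```python
import asyncio


class CompileOnce:
    """Cache the result of an async compile step; compile at most once."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # async () -> compiled object
        self._value = None
        self._lock = asyncio.Lock()

    async def get(self):
        if self._value is not None:  # fast path, no lock
            return self._value
        async with self._lock:
            if self._value is None:  # re-check: another task may have compiled it
                self._value = await self._compile_fn()
            return self._value

    async def invalidate(self):
        async with self._lock:
            self._value = None


compile_calls = 0


async def fake_compile():
    global compile_calls
    compile_calls += 1
    await asyncio.sleep(0.01)  # simulate an expensive compile
    return object()


async def main():
    cache = CompileOnce(fake_compile)
    graphs = await asyncio.gather(*(cache.get() for _ in range(20)))
    same = all(g is graphs[0] for g in graphs)
    await cache.invalidate()
    fresh = await cache.get()
    return same, fresh is not graphs[0]


outcome = asyncio.run(main())
print(outcome, compile_calls)  # (True, True) 2
```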

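Problem 2: Concurrent session management

Symptom: session state got mixed up when multiple users were active at the same time.

Root cause analysis

  1. State storage was not properly isolated per session
  2. There was no session-level locking

Solution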

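import asyncio
from typing import Any, Awaitable, Callable


class SessionManager:
    """Session manager with per-session locks"""
    
    def __init__(self):
        self._session_locks: dict[str, asyncio.Lock] = {}  # {thread_id: asyncio.Lock}
    
    def acquire_lock(self, thread_id: str) -> asyncio.Lock:
        """Return the lock for a session, creating it on first use.

        Deliberately synchronous: it returns the Lock itself so it can be used
        directly in `async with` (an async def would return a coroutine, which
        is not an async context manager), and with no await between the lookup
        and the insert, two coroutines cannot create two locks for one session.
        """
        if thread_id not in self._session_locks:
            self._session_locks[thread_id] = asyncio.Lock()
        
        return self._session_locks[thread_id]
    
    async def with_session_lock(
        self,
        thread_id: str,
        func: Callable[..., Awaitable[Any]],
        *args,
        **kwargs,
    ):
        """Run func while holding the session's lock"""
        async with self.acquire_lock(thread_id):
            return await func(*args, **kwargs)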
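One thing the manager above leaves open is cleanup: `_session_locks` gains an entry for every `thread_id` ever seen and never shrinks. The following sketch (an assumption, not part of the original design) keeps a count of coroutines holding or waiting on each session's lock and drops the entry when that count reaches zero. The helper `peak_concurrency` simulates a few turns to show that turns in the same session run one at a time while different sessions proceed in parallel.

```python
import asyncio
import contextlib


class SessionLocks:
    """Per-session locks that are dropped once no coroutine holds or awaits them."""

    def __init__(self):
        # thread_id -> [lock, number of coroutines holding or waiting]
        self._entries: dict[str, list] = {}

    @contextlib.asynccontextmanager
    async def hold(self, thread_id: str):
        entry = self._entries.get(thread_id)
        if entry is None:
            entry = self._entries[thread_id] = [asyncio.Lock(), 0]
        entry[1] += 1  # no await since the lookup, so this is race-free in one event loop
        try:
            async with entry[0]:
                yield
        finally:
            entry[1] -= 1
            if entry[1] == 0:  # nobody holds or waits: forget this session's lock
                del self._entries[thread_id]


async def peak_concurrency(locks: SessionLocks, thread_ids: list[str]) -> int:
    """Run one simulated turn per thread_id; return how many ran at the same time."""
    active = peak = 0

    async def turn(thread_id: str):
        nonlocal active, peak
        async with locks.hold(thread_id):
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(0.01)  # simulate graph work
            active -= 1

    await asyncio.gather(*(turn(t) for t in thread_ids))
    return peak
```

Since the count is incremented before waiting on the lock, a session with queued requests keeps its lock alive; only an idle session loses its entry, and its next request simply creates a fresh lock.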

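Problem 3: Streaming output too large

Symptom: the streamed output contained a lot of repeated data, putting heavy load on the network.

Root cause

  1. Every state update sent the complete state
  2. There was no data compression mechanism

Solution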

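import copy


class DeltaEncoder:
    """Delta encoder: keeps only the keys whose values changed since the last call.

    Note: keys that disappear from the state are not reported.
    """
    
    def __init__(self):
        self._previous_state = {}
    
    def encode(self, current_state: dict) -> dict:
        """Compute the delta, then remember the current state as the new baseline"""
        delta = {}
        
        for key, value in current_state.items():
            if key not in self._previous_state or self._previous_state[key] != value:
                delta[key] = value
        
        # Deep copy: a shallow copy would share nested lists/dicts with the caller,
        # so in-place edits to them would also change the baseline and go unreported
        self._previous_state = copy.deepcopy(current_state)
        
        return delta
    
    def reset(self):
        """Reset the baseline (call when a new session starts)"""
        self._previous_state = {}


# Usage example
async def stream_with_delta(executor, input_data):
    # One encoder per stream: a single module-level encoder would be shared by
    # concurrent requests, each diffing against another request's state
    encoder = DeltaEncoder()
    
    # Assumes executor.stream() is an async generator yielding the full state
    async for state in executor.stream(input_data):
        delta = encoder.encode(state)
        if delta:  # send only when something changed
            yield delta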
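Because the encoder sends only changed keys, the client has to merge each delta into its own copy of the state, and a key that disappears from the state is never reported. The sketch below is a variant that also reports removals, using a hypothetical `{"changed": {...}, "removed": [...]}` wire format (not the project's actual protocol), together with a client-side `apply_delta` that rebuilds the full state.

```python
import copy

_MISSING = object()  # sentinel so a stored None is not confused with "absent"


class DeltaEncoderWithRemovals:
    """Variant of DeltaEncoder that also reports keys that disappeared."""

    def __init__(self):
        self._previous: dict = {}

    def encode(self, current: dict) -> dict:
        changed = {k: v for k, v in current.items() if self._previous.get(k, _MISSING) != v}
        removed = [k for k in self._previous if k not in current]
        self._previous = copy.deepcopy(current)
        delta = {}
        if changed:
            delta["changed"] = changed
        if removed:
            delta["removed"] = removed
        return delta


def apply_delta(state: dict, delta: dict) -> dict:
    """Client side: rebuild the full state from the previous one plus a delta."""
    new_state = {**state, **delta.get("changed", {})}
    for key in delta.get("removed", []):
        new_state.pop(key, None)
    return new_state
```

Applying the deltas in order should reproduce every intermediate state exactly, which is the property worth checking whenever the encoding changes.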

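Problem 4: Dynamic registration of MCP services

Symptom: MCP services need to be registered dynamically at runtime.

Root cause

  1. Service configuration may change at runtime
  2. Hot reloading must be supported

Solution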

import asyncio
from typing import Any, Awaitable, Callable, Optional


class MCPRegistry:
    """MCP service registry (supports dynamic registration).

    Assumes every client returned by a factory exposes an async close().
    """
    
    def __init__(self):
        self._services = {}
        self._service_factories = {}
        self._registry_lock = asyncio.Lock()
    
    async def register_service(
        self,
        name: str,
        factory: Callable[..., Awaitable[Any]],
        config: Optional[dict] = None,
    ):
        """Register (or re-register) a service at runtime"""
        async with self._registry_lock:
            self._service_factories[name] = {
                "factory": factory,
                "config": config or {},
            }
            
            # If the service is already running, close it; the next
            # get_client() call recreates it with the new factory/config
            if name in self._services:
                await self._services[name].close()
                del self._services[name]
    
    async def get_client(self, name: str):
        """Get the service client (created lazily on first use)"""
        async with self._registry_lock:
            if name not in self._services:
                if name not in self._service_factories:
                    raise ValueError(f"Service {name} not registered")
                
                factory_info = self._service_factories[name]
                self._services[name] = await factory_info["factory"](**factory_info["config"])
            
            return self._services[name]
    
    async def reload_service(self, name: str):
        """Reload a service: drop the current client so the next get_client() recreates it"""
        async with self._registry_lock:
            if name in self._services:
                await self._services[name].close()
                del self._services[name]
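The registry handles one service at a time. When the whole MCP configuration changes at runtime, the caller still has to work out which services are new, which changed, and which were dropped. A minimal sketch of that diff, assuming the configuration arrives as a `{service_name: config_dict}` mapping (`plan_reload` and the service names in the test are hypothetical):

```python
def plan_reload(old: dict[str, dict], new: dict[str, dict]) -> dict[str, list[str]]:
    """Diff two {service_name: config} snapshots into register / reload / remove lists."""
    return {
        # Present only in the new config: register for the first time
        "register": sorted(name for name in new if name not in old),
        # Present in both but with a different config: close and recreate
        "reload": sorted(name for name in new if name in old and new[name] != old[name]),
        # Present only in the old config: close and forget
        "remove": sorted(name for name in old if name not in new),
    }
```

Services listed under `register` and `reload` could both go through `register_service`, which already closes a running instance; `remove` would need an unregister method that the registry above does not provide.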

VII. Testing and Verification

Integration tests

# tests/test_integration.py
import pytest

from ai_social_scheduler.platform import PlatformManager
from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph


@pytest.mark.asyncio
async def test_full_workflow():
    """Test the complete workflow (without publishing)"""
    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()
    
    result = await workflow.invoke({
        "description": "Coffee brewing tutorial",
        "image_count": 3,
        "should_publish": False,
    })
    
    assert result.get("success") is True
    assert "content_result" in result
    assert "image_result" in result


@pytest.mark.asyncio
async def test_multi_platform_publish():
    """Test multi-platform publishing"""
    manager = PlatformManager()
    
    content = {
        "title": "Test title",
        "content": "Test content",
        "tags": ["test", "automation"],
    }
    
    result = await manager.publish_to_platforms(
        content=content,
        platforms=["xiaohongshu"],
    )
    
    assert "xiaohongshu" in result

Performance tests

# tests/test_performance.py
import asyncio
import time

import pytest

from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph


@pytest.mark.asyncio
async def test_execution_time():
    """Test single-run execution time"""
    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()
    
    start_time = time.perf_counter()
    
    await workflow.invoke({
        "description": "Test content",
        "image_count": 1,
        "should_publish": False,
    })
    
    elapsed = time.perf_counter() - start_time
    assert elapsed < 30  # should finish within 30 seconds


@pytest.mark.asyncio
async def test_concurrent_execution():
    """Test concurrent execution"""
    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()
    
    tasks = [
        workflow.invoke({
            "description": f"Test task {i}",
            "image_count": 1,
            "should_publish": False,
        })
        for i in range(5)
    ]
    
    start_time = time.perf_counter()
    results = await asyncio.gather(*tasks)
    elapsed = time.perf_counter() - start_time
    
    # Every task must succeed
    for result in results:
        assert result.get("success") is True
    
    # Concurrency should beat serial execution; this is only checked loosely,
    # as an absolute bound: 5 tasks within 60 seconds
    assert elapsed < 60
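The comment in `test_concurrent_execution` says concurrency should beat serial execution, but the assertion only checks an absolute bound. The sketch below measures both paths against a hypothetical `fake_invoke` that mostly waits, as an LLM or image API call would, so the comparison becomes explicit. With real services, upstream rate limits may shrink the speedup considerably.

```python
import asyncio
import time


async def fake_invoke(payload: dict) -> dict:
    """Hypothetical stand-in for workflow.invoke: mostly waiting on I/O."""
    await asyncio.sleep(0.05)
    return {"success": True, "description": payload["description"]}


async def measure(n: int = 5) -> tuple[float, float]:
    """Return (serial seconds, concurrent seconds) for n fake invocations."""
    payloads = [{"description": f"task {i}"} for i in range(n)]

    start = time.perf_counter()
    for p in payloads:
        await fake_invoke(p)  # one after another
    serial = time.perf_counter() - start

    start = time.perf_counter()
    await asyncio.gather(*(fake_invoke(p) for p in payloads))  # all at once
    concurrent = time.perf_counter() - start

    return serial, concurrent
```

In a real test, asserting something like `concurrent < serial / 2` states the intent directly rather than relying on a fixed time budget.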

Stress tests

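# tests/test_stress.py
import asyncio
import time

import pytest

from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph


@pytest.mark.asyncio
async def test_high_concurrency():
    """High-concurrency test"""
    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()
    
    # Simulate 20 concurrent requests
    tasks = [
        asyncio.create_task(
            workflow.invoke({
                "description": f"Stress test task {i}",
                "image_count": 1,
                "should_publish": False,
            })
        )
        for i in range(20)
    ]
    
    start_time = time.perf_counter()
    
    # Overall timeout for the whole batch
    try:
        results = await asyncio.wait_for(
            # return_exceptions=True: a failed task counts against the success
            # rate instead of aborting the whole test on the first exception
            asyncio.gather(*tasks, return_exceptions=True),
            timeout=120,  # 2-minute timeout
        )
    except asyncio.TimeoutError:
        pytest.fail("High-concurrency test timed out")
    
    elapsed = time.perf_counter() - start_time
    
    # Check the success ratio; exceptions count as failures
    success_count = sum(
        1 for r in results if isinstance(r, dict) and r.get("success")
    )
    success_rate = success_count / len(results)
    
    assert success_rate >= 0.9  # at least a 90% success rate
    
    print(f"High-concurrency test done: {len(results)} tasks, {elapsed:.2f} s, success rate {success_rate:.2%}")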
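The 90% threshold only means something if failed tasks are counted rather than raised: with a plain `asyncio.gather`, the first exception propagates and the remaining results are lost. A small sketch of helpers that the performance and stress tests could share (`run_batch` and `success_rate` are hypothetical names, not part of the project):

```python
import asyncio


async def run_batch(coros) -> list:
    """Await all coroutines; failures come back as exception objects, not raised."""
    return await asyncio.gather(*coros, return_exceptions=True)


def success_rate(results: list) -> float:
    """Fraction of results that are dicts with success=True; exceptions count as failures."""
    if not results:
        return 0.0
    ok = sum(1 for r in results if isinstance(r, dict) and r.get("success") is True)
    return ok / len(results)
```

Keeping the counting rule in one place also makes it easy to report which inputs failed, since each exception stays at the index of the task that raised it.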

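VIII. Progress and Next Steps

Work completed this week

| Module | Work done |
| --- | --- |
| Workflow engine | Graph executor refactor, enhanced state management |
| Intelligent routing | New routing strategies, enhanced intent analyzer |
| Streaming API | SSE protocol optimization, heartbeat mechanism, multi-mode support |
| Xiaohongshu workflow | Multi-platform support, optimized content generation |
| Error handling | Unified error handling, retry mechanism |
| Integration tests | Full workflow tests, performance tests |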

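Next steps

  1. Optimize system performance

    • Implement graph execution caching
    • Optimize database queries
    • Add performance monitoring
  2. Build a management interface

    • Web admin console
    • Task monitoring dashboard
    • Data analysis reports
  3. Complete the deployment setup

    • Docker containerized deployment
    • Kubernetes support
    • Automated deployment scripts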

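Summary

The key work completed this week includes:

  1. Workflow engine optimization: refactored the graph executor, enhanced state management, and implemented unified error handling and a retry mechanism
  2. Intelligent routing expansion: added several new routing strategies and extended the intent analyzer's multi-model support
  3. Streaming API improvements: optimized the SSE protocol implementation and added a heartbeat mechanism and multi-mode support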