Spring AI 持久化对话完整指南:工具调用与自定义实现

前言

在开发 AI 应用时,对话持久化是一个非常重要的功能。它能够让应用记住用户的对话历史,提供更连贯的交互体验。Spring AI 提供了强大的对话记忆管理功能,本文将详细介绍如何实现对话持久化,包括工具调用的对话持久化以及自定义持久化方案。

一、Spring AI 对话持久化概述

1.1 为什么需要对话持久化

  • 上下文连贯性:让 AI 能够理解前面对话的内容
  • 用户体验优化:避免用户重复提供相同信息
  • 多轮对话支持:实现复杂的交互流程
  • 工具调用追踪:记录 AI 调用工具的历史

1.2 Spring AI 的记忆管理架构

Spring AI 提供了 ChatMemory 接口来管理对话记忆,支持多种实现方式:

  • InMemoryChatMemory:内存存储,适合简单场景
  • RedisChatMemory:Redis 存储,支持分布式
  • CustomChatMemory:自定义实现,支持任意存储介质

二、基础对话持久化实现

2.1 添加依赖

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
    <version>1.0.0-M4</version>
</dependency>
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-redis-store-spring-boot-starter</artifactId>
    <version>1.0.0-M4</version>
</dependency>

2.2 内存存储实现

@Configuration
public class ChatMemoryConfig {
    
    @Bean
    public ChatMemory chatMemory() {
        return new InMemoryChatMemory();
    }
}

2.3 使用示例

@Service
public class ChatService {
    
    private final ChatClient chatClient;
    
    public ChatService(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
        this.chatClient = chatClientBuilder
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory))
            .build();
    }
    
    public String chat(String userId, String message) {
        return chatClient.prompt()
            .user(message)
            .advisors(a -> a
                .param(CHAT_MEMORY_CONVERSATION_ID_KEY, userId)
                .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
            .call()
            .content();
    }
}

三、Redis 持久化实现

3.1 配置 Redis 存储

@Configuration
public class RedisChatMemoryConfig {
    
    @Bean
    public ChatMemory redisChatMemory(RedisConnectionFactory redisConnectionFactory) {
        return new RedisChatMemory(redisConnectionFactory);
    }
}

3.2 Redis 配置

spring.redis.host=localhost
spring.redis.port=6379
spring.ai.chat.memory.redis.ttl=3600
spring.ai.chat.memory.redis.key-prefix=chat:

3.3 Redis 数据结构

Spring AI 在 Redis 中使用 Hash 结构存储对话历史:

Key: chat:conversation:{userId}
Fields:
  - messages: JSON 格式的消息列表
  - timestamp: 最后更新时间

四、工具调用的对话持久化

4.1 工具调用记录的重要性

工具调用(Tool Calls)是 AI 与外部系统交互的重要方式,记录这些调用对于:

  • 调试和问题排查
  • 理解 AI 的决策过程
  • 审计和合规要求
  • 优化工具使用效率

4.2 工具调用持久化配置

@Configuration
public class ToolCallMemoryConfig {
    
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder, ChatMemory chatMemory) {
        return builder
            .defaultAdvisors(
                new MessageChatMemoryAdvisor(chatMemory),
                new ToolCallMemoryAdvisor()
            )
            .build();
    }
}

4.3 工具调用记录结构

工具调用记录包含以下信息:

public class ToolCallRecord {
    private String toolName;          // 工具名称
    private String toolId;            // 工具ID
    private Map<String, Object> arguments; // 调用参数
    private Object result;            // 返回结果
    private long executionTime;       // 执行时间
    private boolean success;          // 是否成功
    private String errorMessage;      // 错误信息
}

4.4 自定义工具调用记录器

@Component
public class CustomToolCallLogger implements ToolCallListener {
    
    private final ToolCallRepository repository;
    
    @Override
    public void beforeToolCall(ToolCallRequest request) {
        log.info("准备调用工具: {}, 参数: {}", 
            request.getToolName(), request.getArguments());
    }
    
    @Override
    public void afterToolCall(ToolCallResponse response) {
        ToolCallRecord record = new ToolCallRecord();
        record.setToolName(response.getToolName());
        record.setResult(response.getResult());
        record.setExecutionTime(response.getExecutionTime());
        record.setSuccess(response.isSuccess());
        
        repository.save(record);
        
        log.info("工具调用完成: {}, 耗时: {}ms", 
            response.getToolName(), response.getExecutionTime());
    }
}

4.5 持久化工具调用到数据库

创建数据表:

CREATE TABLE tool_calls (
    id BIGINT PRIMARY KEY AUTO_INCREMENT,
    conversation_id VARCHAR(255) NOT NULL,
    tool_name VARCHAR(100) NOT NULL,
    tool_id VARCHAR(100),
    arguments TEXT,
    result TEXT,
    execution_time BIGINT,
    success BOOLEAN,
    error_message TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    INDEX idx_conversation (conversation_id),
    INDEX idx_tool_name (tool_name)
);

Repository 实现:

@Repository
public interface ToolCallRepository extends JpaRepository<ToolCallRecord, Long> {
    
    List<ToolCallRecord> findByConversationIdOrderByCreatedAt(String conversationId);
    
    @Query("SELECT t.toolName, COUNT(*) FROM ToolCallRecord t " +
           "WHERE t.createdAt > :startDate GROUP BY t.toolName")
    List<Object[]> findToolUsageStats(@Param("startDate") LocalDateTime startDate);
}

五、自定义持久化实现

5.1 实现自定义 ChatMemory

public class CustomDatabaseChatMemory implements ChatMemory {
    
    private final ConversationMessageRepository repository;
    private final int maxMessages;
    
    public CustomDatabaseChatMemory(ConversationMessageRepository repository, 
                                     int maxMessages) {
        this.repository = repository;
        this.maxMessages = maxMessages;
    }
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        for (Message message : messages) {
            ConversationMessage cm = new ConversationMessage();
            cm.setConversationId(conversationId);
            cm.setMessageType(message.getMessageType().name());
            cm.setContent(message.getContent());
            cm.setMetadata(extractMetadata(message));
            repository.save(cm);
        }
        
        // 限制消息数量
        trimMessages(conversationId);
    }
    
    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<ConversationMessage> records = repository
            .findByConversationIdOrderByCreatedAtDesc(conversationId, lastN);
        
        return records.stream()
            .sorted(Comparator.comparing(ConversationMessage::getCreatedAt))
            .map(this::toMessage)
            .collect(Collectors.toList());
    }
    
    @Override
    public void clear(String conversationId) {
        repository.deleteByConversationId(conversationId);
    }
    
    private void trimMessages(String conversationId) {
        long count = repository.countByConversationId(conversationId);
        if (count > maxMessages) {
            List<ConversationMessage> oldMessages = repository
                .findOldestMessages(conversationId, (int)(count - maxMessages));
            repository.deleteAll(oldMessages);
        }
    }
}

5.2 数据库表设计

CREATE TABLE conversation_messages (
    id BIGINT PRIMARY KEY AUTO_INCREMENT,
    conversation_id VARCHAR(255) NOT NULL,
    message_type VARCHAR(50) NOT NULL,
    content TEXT NOT NULL,
    metadata TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    INDEX idx_conversation_created (conversation_id, created_at)
);

5.3 多租户持久化实现

public class MultiTenantChatMemory implements ChatMemory {
    
    private final ChatMemory delegate;
    private final TenantProvider tenantProvider;
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        String tenantId = tenantProvider.getCurrentTenantId();
        String scopedConversationId = tenantId + ":" + conversationId;
        delegate.add(scopedConversationId, messages);
    }
    
    @Override
    public List<Message> get(String conversationId, int lastN) {
        String tenantId = tenantProvider.getCurrentTenantId();
        String scopedConversationId = tenantId + ":" + conversationId;
        return delegate.get(scopedConversationId, lastN);
    }
}

5.4 分层存储策略

public class TieredChatMemory implements ChatMemory {
    
    private final ChatMemory hotStorage;  // Redis - 热数据
    private final ChatMemory coldStorage; // Database - 冷数据
    private final int hotSizeLimit;
    
    @Override
    public List<Message> get(String conversationId, int lastN) {
        // 先从热存储获取
        List<Message> messages = hotStorage.get(conversationId, lastN);
        
        if (messages.size() < lastN) {
            // 从冷存储补充
            int needed = lastN - messages.size();
            List<Message> coldMessages = coldStorage.get(conversationId, needed);
            messages.addAll(0, coldMessages);
        }
        
        return messages;
    }
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        hotStorage.add(conversationId, messages);
        coldStorage.add(conversationId, messages);
        
        // 定期清理热存储
        if (hotStorage.get(conversationId, Integer.MAX_VALUE).size() > hotSizeLimit) {
            trimHotStorage(conversationId);
        }
    }
}

六、高级特性

6.1 对话总结与压缩

@Component
public class ConversationSummarizer {
    
    private final ChatClient summarizerClient;
    
    public String summarizeConversation(String conversationId, List<Message> messages) {
        String prompt = """
            请总结以下对话的核心内容,保持重要信息:
            %s
            
            总结要求:
            1. 保留关键决策和结论
            2. 记录重要的工具调用结果
            3. 简明扼要,不超过200字
            """.formatted(formatMessages(messages));
        
        return summarizerClient.prompt()
            .user(prompt)
            .call()
            .content();
    }
}

6.2 对话索引与搜索

@Service
public class ConversationSearchService {
    
    private final ElasticsearchTemplate elasticsearchTemplate;
    
    public void indexConversation(String conversationId, List<Message> messages) {
        ConversationDocument doc = new ConversationDocument();
        doc.setConversationId(conversationId);
        doc.setContent(messages.stream()
            .map(Message::getContent)
            .collect(Collectors.joining("\n")));
        doc.setTimestamp(Instant.now());
        
        elasticsearchTemplate.save(doc);
    }
    
    public List<String> searchConversations(String keyword, int topK) {
        NativeSearchQuery query = NativeSearchQuery.builder()
            .withQuery(QueryBuilders.matchQuery("content", keyword))
            .withPageable(PageRequest.of(0, topK))
            .build();
        
        return elasticsearchTemplate.search(query, ConversationDocument.class)
            .stream()
            .map(hit -> hit.getContent().getConversationId())
            .collect(Collectors.toList());
    }
}

6.3 对话导出与导入

@Service
public class ConversationExportService {
    
    public String exportConversation(String conversationId, String format) {
        List<Message> messages = chatMemory.get(conversationId, Integer.MAX_VALUE);
        
        return switch (format.toLowerCase()) {
            case "json" -> exportToJson(messages);
            case "markdown" -> exportToMarkdown(messages);
            case "txt" -> exportToText(messages);
            default -> throw new IllegalArgumentException("Unsupported format: " + format);
        };
    }
    
    private String exportToMarkdown(List<Message> messages) {
        StringBuilder sb = new StringBuilder();
        sb.append("# 对话记录\n\n");
        
        for (Message msg : messages) {
            String role = msg.getMessageType() == MessageType.USER ? "用户" : "AI";
            sb.append("## %s\n\n%s\n\n", role, msg.getContent());
        }
        
        return sb.toString();
    }
    
    public void importConversation(String conversationId, String jsonData) {
        List<Message> messages = parseMessagesFromJson(jsonData);
        chatMemory.add(conversationId, messages);
    }
}

七、最佳实践

7.1 性能优化

  • 批量操作:使用批量插入减少数据库调用
  • 异步持久化:使用消息队列异步处理持久化
  • 缓存策略:对热点对话使用内存缓存
  • 分页加载:避免一次性加载过多历史消息

7.2 数据安全

  • 敏感信息过滤:持久化前过滤敏感数据
  • 数据加密:对存储的对话内容加密
  • 访问控制:实现基于角色的数据访问控制
  • 数据脱敏:导出时对敏感信息脱敏处理

7.3 监控与告警

@Component
public class ChatMemoryMonitor {
    
    @Scheduled(fixedRate = 60000)
    public void monitorMemoryUsage() {
        long totalMessages = repository.count();
        long activeConversations = repository.countActiveConversations();
        
        if (totalMessages > WARNING_THRESHOLD) {
            alertService.sendAlert("对话消息数量超过阈值: " + totalMessages);
        }
        
        metrics.gauge("chat.memory.total.messages", totalMessages);
        metrics.gauge("chat.memory.active.conversations", activeConversations);
    }
}

八、总结

Spring AI 提供了灵活而强大的对话持久化机制:

  1. 多种存储方案:支持内存、Redis、数据库等多种存储方式
  2. 工具调用追踪:完整记录 AI 的工具调用历史
  3. 高度可定制:可以轻松实现自定义的持久化策略
  4. 企业级特性:支持多租户、分层存储、数据安全等企业需求

通过合理使用这些功能,可以构建出功能完善、性能优秀的 AI 对话应用。

参考资源

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐