Spring AI 1.0.4 Conversation Persistence: A Complete Guide, from Basics to Custom Implementations

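Preface

As AI applications become widespread, conversational AI has become a core feature of many products. Spring AI, the AI framework in the Spring ecosystem, supports conversation persistence through its ChatMemory abstraction. This article looks at how that mechanism works in Spring AI 1.0.4: basic configuration, persisting conversations that involve tool calls, and building a custom persistence layer.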

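1. Overview of Conversation Persistence in Spring AI

1.1 What Is Conversation Persistence

Conversation persistence means writing the interaction history between a user and the AI to durable storage, so that:

  • an earlier conversation context can be restored
  • multi-turn conversations stay continuous
  • conversation history can be managed and analyzed
  • conversations can be synchronized across devices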
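
To make the idea concrete, here is a minimal, framework-free sketch of what a conversation store does: it groups messages by conversation ID and returns the most recent ones when context is restored. The class `ConversationStoreSketch` and the `StoredMessage` record are hypothetical names used only for illustration; they are not Spring AI types, and the in-memory map merely stands in for a real database.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: messages grouped by conversation ID (not thread-safe). */
public class ConversationStoreSketch {

    /** Illustrative message shape; not a Spring AI type. */
    public record StoredMessage(String role, String content) {}

    // Stands in for durable storage such as a database table
    private final Map<String, List<StoredMessage>> store = new HashMap<>();

    /** Appends one message to the given conversation, creating it if needed. */
    public void add(String conversationId, StoredMessage message) {
        store.computeIfAbsent(conversationId, id -> new ArrayList<>()).add(message);
    }

    /** Returns up to n of the most recent messages, oldest first. */
    public List<StoredMessage> lastN(String conversationId, int n) {
        if (n <= 0) {
            return List.of();
        }
        List<StoredMessage> all = store.getOrDefault(conversationId, List.of());
        int from = Math.max(0, all.size() - n);
        return List.copyOf(all.subList(from, all.size()));
    }
}
```

Restoring context then amounts to calling `lastN` with the conversation ID before the next model request; an unknown ID simply yields an empty list.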

1.2 Persistence Architecture in Spring AI 1.0.4

┌─────────────┐
│   User      │
└──────┬──────┘
       │
       ▼
┌─────────────┐     ┌─────────────┐
│   Chat      │────▶│  ChatMemory │
│  Interface  │     │   Service   │
└──────┬──────┘     └──────┬──────┘
       │                  │
       │              ┌───▼───┐
       │              │ Store │
       │              └───┬───┘
       ▼                  │
┌─────────────┐     ┌─────▼─────┐
│    AI       │     │  Database │
│  Service    │     │  / Redis  │
└─────────────┘     └───────────┘

2. Basic Conversation Persistence Configuration

2.1 Dependencies

Add the required dependencies to pom.xml:

<dependencies>
    <!-- Spring AI OpenAI starter (named spring-ai-starter-model-openai in the 1.0.x line) -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
        <version>1.0.4</version>
    </dependency>
    
    <!-- JPA persistence -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>
    
    <!-- H2 database (development) -->
    <dependency>
        <groupId>com.h2database</groupId>
        <artifactId>h2</artifactId>
        <scope>runtime</scope>
    </dependency>
    
    <!-- MySQL driver (production); the current coordinates are com.mysql:mysql-connector-j -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>
</dependencies>

2.2 application.yml Configuration

spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-4
          temperature: 0.7
  
  datasource:
    url: jdbc:h2:mem:testdb
    driver-class-name: org.h2.Driver
    username: sa
    password: 
  
  jpa:
    hibernate:
      ddl-auto: update
    show-sql: true
    properties:
      hibernate:
        format_sql: true
  
  h2:
    console:
      enabled: true

2.3 Basic Entity Design

import jakarta.persistence.*;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;

@Entity
@Table(name = "chat_conversation")
public class ChatConversation {
    
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    
    @Column(nullable = false, unique = true)
    private String conversationId;
    
    @Column(nullable = false)
    private String userId;
    
    @Column(length = 500)
    private String title;
    
    @Column(nullable = false)
    private LocalDateTime createdAt;
    
    @Column(nullable = false)
    private LocalDateTime updatedAt;
    
    @OneToMany(mappedBy = "conversation", cascade = CascadeType.ALL, orphanRemoval = true)
    @OrderBy("timestamp ASC, id ASC") // keep messages in insertion order when loaded
    private List<ChatMessage> messages = new ArrayList<>();
    
    // Constructor, getters, and setters
    public ChatConversation() {
        this.createdAt = LocalDateTime.now();
        this.updatedAt = LocalDateTime.now();
    }
    
    public void addMessage(ChatMessage message) {
        messages.add(message);
        message.setConversation(this);
        this.updatedAt = LocalDateTime.now();
    }
    
    // Getters and setters omitted...
}

@Entity
@Table(name = "chat_message")
public class ChatMessage {
    
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    
    @ManyToOne(fetch = FetchType.LAZY)
    @JoinColumn(name = "conversation_id", nullable = false)
    private ChatConversation conversation;
    
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    private MessageType type;
    
    @Column(nullable = false, columnDefinition = "TEXT")
    private String content;
    
    @Column(columnDefinition = "TEXT")
    private String toolCalls;
    
    @Column(columnDefinition = "TEXT")
    private String toolResponses;
    
    @Column(nullable = false)
    private LocalDateTime timestamp;
    
    private Integer tokens;
    
    // Constructor, getters, and setters
    public ChatMessage() {
        this.timestamp = LocalDateTime.now();
    }
    
    public enum MessageType {
        USER,
        ASSISTANT,
        SYSTEM,
        TOOL
    }
    
    // Getters and setters omitted...
}

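2.4 Repository Layer

import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;

@Repository
public interface ChatConversationRepository extends JpaRepository<ChatConversation, Long> {
    
    Optional<ChatConversation> findByConversationId(String conversationId);
    
    List<ChatConversation> findByUserIdOrderByUpdatedAtDesc(String userId);
    
    // Used by the scheduled cleanup job in section 7.3
    List<ChatConversation> findByUpdatedAtBefore(LocalDateTime cutoff);
    
    // Used by ChatService.deleteConversation; must be called inside a transaction
    void deleteByConversationId(String conversationId);
    
    @Query("SELECT c FROM ChatConversation c WHERE c.userId = :userId " +
           "ORDER BY c.updatedAt DESC")
    List<ChatConversation> findRecentConversations(@Param("userId") String userId, 
                                                      Pageable pageable);
}

// The other repositories referenced in later sections need no custom queries
@Repository
public interface ChatMessageRepository extends JpaRepository<ChatMessage, Long> {
}

@Repository
public interface ToolExecutionRepository extends JpaRepository<ToolExecution, Long> {
}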

3. Implementing the ChatMemory Interface

3.1 A Custom ChatMemory Implementation

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;
import java.util.UUID;

@Service
public class PersistentChatMemory implements ChatMemory {
    
    private final ChatConversationRepository conversationRepository;
    private final ChatMessageRepository messageRepository;
    
    // Maximum number of history messages returned by get(conversationId)
    private static final int MAX_MESSAGES = 50;
    
    public PersistentChatMemory(ChatConversationRepository conversationRepository,
                                 ChatMessageRepository messageRepository) {
        this.conversationRepository = conversationRepository;
        this.messageRepository = messageRepository;
    }
    
    @Override
    @Transactional
    public void add(String conversationId, List<Message> messages) {
        ChatConversation conversation = conversationRepository
            .findByConversationId(conversationId)
            .orElseGet(() -> createNewConversation(conversationId));
        
        for (Message message : messages) {
            ChatMessage chatMessage = convertToChatMessage(message, conversation);
            conversation.addMessage(chatMessage);
        }
        
        conversationRepository.save(conversation);
    }
    
    // Convenience overload; not part of the ChatMemory interface in Spring AI 1.0.x,
    // so it carries no @Override
    @Transactional(readOnly = true)
    public List<Message> get(String conversationId, int lastN) {
        ChatConversation conversation = conversationRepository
            .findByConversationId(conversationId)
            .orElse(null);
        
        if (conversation == null) {
            return List.of();
        }
        
        return conversation.getMessages().stream()
            .skip(Math.max(0, conversation.getMessages().size() - lastN))
            .map(this::convertToSpringAIMessage)
            .toList();
    }
    
    @Override
    @Transactional(readOnly = true)
    public List<Message> get(String conversationId) {
        return get(conversationId, MAX_MESSAGES);
    }
    
    @Override
    @Transactional
    public void clear(String conversationId) {
        ChatConversation conversation = conversationRepository
            .findByConversationId(conversationId)
            .orElse(null);
        
        if (conversation != null) {
            conversation.getMessages().clear();
            conversationRepository.save(conversation);
        }
    }
    
    private ChatConversation createNewConversation(String conversationId) {
        ChatConversation conversation = new ChatConversation();
        conversation.setConversationId(conversationId != null ? conversationId : UUID.randomUUID().toString());
        conversation.setUserId("default_user"); // could be taken from the request context
        conversation.setTitle("New conversation");
        return conversation;
    }
    
    private ChatMessage convertToChatMessage(Message message, ChatConversation conversation) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setConversation(conversation);
        // Message exposes its text via getText(); guard against null because the column is NOT NULL
        chatMessage.setContent(message.getText() != null ? message.getText() : "");
        chatMessage.setType(mapMessageType(message.getMessageType()));
        
        // Keep the tool calls carried by assistant messages
        if (message instanceof AssistantMessage assistant
                && assistant.getToolCalls() != null
                && !assistant.getToolCalls().isEmpty()) {
            chatMessage.setToolCalls(assistant.getToolCalls().toString());
        }
        
        return chatMessage;
    }
    
    private ChatMessage.MessageType mapMessageType(MessageType type) {
        return switch (type) {
            case USER -> ChatMessage.MessageType.USER;
            case ASSISTANT -> ChatMessage.MessageType.ASSISTANT;
            case SYSTEM -> ChatMessage.MessageType.SYSTEM;
            case TOOL -> ChatMessage.MessageType.TOOL;
        };
    }
    
    private Message convertToSpringAIMessage(ChatMessage chatMessage) {
        return switch (chatMessage.getType()) {
            case USER -> new UserMessage(chatMessage.getContent());
            // Tool calls are not restored in this example. AssistantMessage is immutable, so
            // stored tool calls would have to be deserialized and passed to its constructor.
            case ASSISTANT -> new AssistantMessage(chatMessage.getContent());
            case SYSTEM -> new SystemMessage(chatMessage.getContent());
            // Restored without its payload here; a full implementation would rebuild the
            // ToolResponseMessage.ToolResponse entries (tool call ID, name, data).
            case TOOL -> new ToolResponseMessage(List.of());
        };
    }
}

3.2 Configuring ChatClient to Use the Custom ChatMemory

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class ChatConfig {
    
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder, 
                                  ChatMemory chatMemory) {
        return builder
            .defaultSystem("You are a helpful AI assistant")
            // Spring AI 1.0.x creates advisors through builders rather than public constructors
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
            .build();
    }
}

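4. Persisting Conversations with Tool Calls

4.1 Entity for Tool Executions

import jakarta.persistence.*;
import java.time.LocalDateTime;

@Entity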
@Table(name = "tool_execution")
public class ToolExecution {
    
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    
    @ManyToOne(fetch = FetchType.LAZY)
    @JoinColumn(name = "chat_message_id", nullable = false)
    private ChatMessage chatMessage;
    
    @Column(nullable = false)
    private String toolName;
    
    @Column(columnDefinition = "TEXT")
    private String arguments;
    
    @Column(columnDefinition = "TEXT") // results can exceed the default 255 characters
    private String result;
    
    @Column
    private Long executionTime; // execution time in milliseconds
    
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    private ExecutionStatus status;
    
    @Column(nullable = false)
    private LocalDateTime executedAt;
    
    @Column(length = 1000)
    private String errorMessage;
    
    public enum ExecutionStatus {
        SUCCESS,
        FAILED,
        TIMEOUT
    }
    
    public ToolExecution() {
        this.executedAt = LocalDateTime.now();
    }
    
    // Getters and setters omitted...
}

4.2 Tool Call Listener

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.List;

@Component
public class ToolExecutionListener {
    
    private final ToolExecutionRepository toolExecutionRepository;
    private final ChatMessageRepository chatMessageRepository;
    private final ObjectMapper objectMapper;
    
    public ToolExecutionListener(ToolExecutionRepository toolExecutionRepository,
                                  ChatMessageRepository chatMessageRepository,
                                  ObjectMapper objectMapper) {
        this.toolExecutionRepository = toolExecutionRepository;
        this.chatMessageRepository = chatMessageRepository;
        this.objectMapper = objectMapper;
    }
    
    @Transactional
    public void recordToolExecution(Long chatMessageId, 
                                     String toolName,
                                     String argumentsJson,
                                     Object result,
                                     long executionTime,
                                     ToolExecution.ExecutionStatus status,
                                     String errorMessage) {
        ChatMessage chatMessage = chatMessageRepository.findById(chatMessageId)
            .orElseThrow(() -> new IllegalArgumentException("Chat message not found"));
        
        ToolExecution execution = new ToolExecution();
        execution.setChatMessage(chatMessage);
        execution.setToolName(toolName);
        // AssistantMessage.ToolCall already exposes its arguments as a JSON string
        execution.setArguments(argumentsJson);
        
        if (result != null) {
            try {
                execution.setResult(objectMapper.writeValueAsString(result));
            } catch (Exception e) {
                // Fall back to toString() when the result cannot be serialized
                execution.setResult(result.toString());
            }
        }
        
        execution.setExecutionTime(executionTime);
        execution.setStatus(status);
        execution.setErrorMessage(errorMessage);
        
        toolExecutionRepository.save(execution);
    }
    
    @Transactional
    public void recordAssistantMessageWithToolCalls(ChatMessage chatMessage, 
                                                     List<AssistantMessage.ToolCall> toolCalls) {
        try {
            String toolCallsJson = objectMapper.writeValueAsString(toolCalls);
            chatMessage.setToolCalls(toolCallsJson);
            chatMessageRepository.save(chatMessage);
        } catch (Exception e) {
            chatMessage.setToolCalls(toolCalls.toString());
            chatMessageRepository.save(chatMessage);
        }
    }
}

4.3 Complete Chat Service Implementation

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;
import java.util.UUID;

@Service
public class ChatService {
    
    private final ChatClient chatClient;
    private final ChatConversationRepository conversationRepository;
    private final ToolExecutionListener toolExecutionListener;
    private final PersistentChatMemory chatMemory;
    
    public ChatService(ChatClient chatClient,
                        ChatConversationRepository conversationRepository,
                        ToolExecutionListener toolExecutionListener,
                        PersistentChatMemory chatMemory) {
        this.chatClient = chatClient;
        this.conversationRepository = conversationRepository;
        this.toolExecutionListener = toolExecutionListener;
        this.chatMemory = chatMemory;
    }
    
    @Transactional
    public ChatResponse chat(String conversationId, String userMessage, List<ToolCallback> tools) {
        // Make sure the conversation exists (the ID must be effectively final for the lambda below)
        final String activeConversationId =
            conversationId != null ? conversationId : UUID.randomUUID().toString();
        if (conversationRepository.findByConversationId(activeConversationId).isEmpty()) {
            ChatConversation conversation = new ChatConversation();
            conversation.setConversationId(activeConversationId);
            conversation.setUserId(getCurrentUserId());
            conversation.setTitle(generateTitle(userMessage));
            conversationRepository.save(conversation);
        }
        
        long startTime = System.currentTimeMillis();
        
        // The MessageChatMemoryAdvisor registered in ChatConfig stores both the user message
        // and the assistant reply, so they are not added to chatMemory by hand here.
        // The conversation ID is passed to the advisor as a request parameter.
        ChatClient.ChatClientRequestSpec request = chatClient.prompt()
            .user(userMessage)
            .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, activeConversationId));
        
        if (tools != null && !tools.isEmpty()) {
            request = request.toolCallbacks(tools);
        }
        
        // Call the AI
        ChatResponse response = request.call().chatResponse();
        
        long executionTime = System.currentTimeMillis() - startTime;
        
        // Record any tool calls carried by the response
        processAndSaveResponse(activeConversationId, response, executionTime);
        
        return response;
    }
    
    private void processAndSaveResponse(String conversationId, 
                                         ChatResponse response,
                                         long executionTime) {
        if (response == null || response.getResult() == null) {
            return;
        }
        AssistantMessage assistantMessage = response.getResult().getOutput();
        
        // By default the model executes tools internally, so the final response often
        // carries no tool calls; this branch only runs when it does.
        if (assistantMessage.getToolCalls() != null && !assistantMessage.getToolCalls().isEmpty()) {
            // Look up the assistant message that the advisor has just saved
            ChatConversation conversation = conversationRepository
                .findByConversationId(conversationId)
                .orElseThrow();
            
            ChatMessage lastMessage = conversation.getMessages()
                .get(conversation.getMessages().size() - 1);
            
            // Record the tool calls on the message
            toolExecutionListener.recordAssistantMessageWithToolCalls(
                lastMessage, 
                assistantMessage.getToolCalls()
            );
            
            // Record the execution details of each tool
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                toolExecutionListener.recordToolExecution(
                    lastMessage.getId(),
                    toolCall.name(),
                    toolCall.arguments(),
                    null, // the result is handled later
                    executionTime,
                    ToolExecution.ExecutionStatus.SUCCESS,
                    null
                );
            }
        }
    }
    
    @Transactional(readOnly = true)
    public List<ChatMessage> getConversationHistory(String conversationId) {
        ChatConversation conversation = conversationRepository
            .findByConversationId(conversationId)
            .orElseThrow(() -> new IllegalArgumentException("Conversation not found"));
        
        return conversation.getMessages();
    }
    
    @Transactional(readOnly = true)
    public List<ChatConversation> getConversationsByUserId(String userId) {
        return conversationRepository.findByUserIdOrderByUpdatedAtDesc(userId);
    }
    
    @Transactional
    public void deleteConversation(String conversationId) {
        // The cascade settings on ChatConversation remove its messages as well
        conversationRepository.deleteByConversationId(conversationId);
    }
    
    private String getCurrentUserId() {
        // Placeholder: resolve the current user from the Spring Security context or similar.
        // Returning a random ID, as here, ties every new conversation to a different user.
        return "user_" + UUID.randomUUID().toString();
    }
    
    private String generateTitle(String firstMessage) {
        // Simplified; in practice the title could be generated by the AI
        return firstMessage.length() > 20 ? 
               firstMessage.substring(0, 20) + "..." : firstMessage;
    }
}

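5. Custom Persistence Implementations

5.1 Redis-Based Persistence

For high-concurrency workloads, Redis can serve as a cache layer in front of the database, giving a hybrid persistence strategy.

Redis configuration
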
spring:
  data:
    redis:
      host: localhost
      port: 6379
      password:
      database: 0
      timeout: 5000ms
      lettuce:
        pool:
          max-active: 8
          max-wait: -1ms
          max-idle: 8
          min-idle: 0

Redis ChatMemory implementation

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;

import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Service
public class RedisChatMemory implements ChatMemory {
    
    // Messages are stored as JSON strings, so StringRedisTemplate avoids serializer mismatches
    private final StringRedisTemplate redisTemplate;
    private final ObjectMapper objectMapper;
    
    private static final String CHAT_KEY_PREFIX = "chat:";
    private static final Duration DEFAULT_TTL = Duration.ofDays(7);
    private static final int MAX_MESSAGES = 100;
    
    public RedisChatMemory(StringRedisTemplate redisTemplate,
                           ObjectMapper objectMapper) {
        this.redisTemplate = redisTemplate;
        this.objectMapper = objectMapper;
    }
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        String key = CHAT_KEY_PREFIX + conversationId;
        
        for (Message message : messages) {
            redisTemplate.opsForList().rightPush(key, serializeMessage(message));
        }
        
        // Refresh the expiry
        redisTemplate.expire(key, DEFAULT_TTL);
        
        // Keep only the newest MAX_MESSAGES entries (negative indexes count from the tail)
        redisTemplate.opsForList().trim(key, -MAX_MESSAGES, -1);
    }
    
    // Convenience overload; not part of the ChatMemory interface in Spring AI 1.0.x
    public List<Message> get(String conversationId, int lastN) {
        if (lastN <= 0) {
            return List.of();
        }
        String key = CHAT_KEY_PREFIX + conversationId;
        
        // Read the last lastN entries; Redis clamps the range to the list bounds
        List<String> entries = redisTemplate.opsForList().range(key, -lastN, -1);
        
        if (entries == null || entries.isEmpty()) {
            return List.of();
        }
        
        return entries.stream()
            .map(this::deserializeMessage)
            .toList();
    }
    
    @Override
    public List<Message> get(String conversationId) {
        return get(conversationId, MAX_MESSAGES);
    }
    
    @Override
    public void clear(String conversationId) {
        String key = CHAT_KEY_PREFIX + conversationId;
        redisTemplate.delete(key);
    }
    
    private String serializeMessage(Message message) {
        try {
            Map<String, Object> messageMap = new HashMap<>();
            messageMap.put("type", message.getMessageType().name());
            messageMap.put("content", message.getText());
            messageMap.put("metadata", message.getMetadata());
            return objectMapper.writeValueAsString(messageMap);
        } catch (Exception e) {
            throw new RuntimeException("Failed to serialize message", e);
        }
    }
    
    private Message deserializeMessage(String json) {
        try {
            Map<String, Object> messageMap =
                objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
            String type = (String) messageMap.get("type");
            String content = (String) messageMap.get("content");
            
            // TOOL and unknown types fall back to an assistant message in this simplified example
            return switch (type) {
                case "USER" -> new UserMessage(content);
                case "ASSISTANT" -> new AssistantMessage(content);
                case "SYSTEM" -> new SystemMessage(content);
                default -> new AssistantMessage(content);
            };
        } catch (Exception e) {
            throw new RuntimeException("Failed to deserialize message", e);
        }
    }
}

5.2 Hybrid Persistence Strategy (Redis + Database)

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.concurrent.CompletableFuture;

// If several ChatMemory beans exist in the context, mark the one that should be injected
// by interface type with @Primary or select it with @Qualifier.
@Service
public class HybridChatMemory implements ChatMemory {
    
    private final RedisChatMemory redisMemory;
    private final PersistentChatMemory dbMemory;
    
    public HybridChatMemory(RedisChatMemory redisMemory,
                             PersistentChatMemory dbMemory) {
        this.redisMemory = redisMemory;
        this.dbMemory = dbMemory;
    }
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        // Write to Redis first (fast)
        redisMemory.add(conversationId, messages);
        
        // Write to the database asynchronously (durable). An @Async method invoked from
        // inside the same class bypasses the Spring proxy and runs synchronously, so the
        // work is handed to CompletableFuture instead. dbMemory.add is @Transactional
        // and is called through its own proxy.
        List<Message> snapshot = List.copyOf(messages);
        CompletableFuture.runAsync(() -> dbMemory.add(conversationId, snapshot))
            .exceptionally(ex -> {
                // Replace with real logging and retry handling
                System.err.println("Failed to persist chat messages: " + ex.getMessage());
                return null;
            });
    }
    
    // Convenience overload; not part of the ChatMemory interface in Spring AI 1.0.x
    public List<Message> get(String conversationId, int lastN) {
        // Try Redis first
        List<Message> messages = redisMemory.get(conversationId, lastN);
        
        // If Redis has nothing, load from the database
        if (messages.isEmpty()) {
            messages = dbMemory.get(conversationId, lastN);
            // Backfill Redis
            if (!messages.isEmpty()) {
                redisMemory.add(conversationId, messages);
            }
        }
        
        return messages;
    }
    
    @Override
    public List<Message> get(String conversationId) {
        return get(conversationId, 50);
    }
    
    @Override
    public void clear(String conversationId) {
        redisMemory.clear(conversationId);
        dbMemory.clear(conversationId);
    }
}
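
The read path above follows a read-through pattern: serve from the cache, fall back to the database on a miss, then backfill the cache. The following framework-free sketch isolates that logic so it can be run without Redis or a database. `ReadThroughSketch` and its two maps are hypothetical stand-ins for the cache and the durable store, not part of Spring AI or Spring Data.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of read-through caching with backfill (not thread-safe). */
public class ReadThroughSketch {

    private final Map<String, List<String>> cache = new HashMap<>();    // stands in for Redis
    private final Map<String, List<String>> database = new HashMap<>(); // stands in for the DB

    /** Writes to both layers; the real code writes to the database asynchronously. */
    public void write(String conversationId, String message) {
        cache.computeIfAbsent(conversationId, k -> new ArrayList<>()).add(message);
        database.computeIfAbsent(conversationId, k -> new ArrayList<>()).add(message);
    }

    /** Simulates a cache expiry or eviction. */
    public void evictFromCache(String conversationId) {
        cache.remove(conversationId);
    }

    public boolean isCached(String conversationId) {
        return cache.containsKey(conversationId);
    }

    /** Reads from the cache, falling back to the database and backfilling on a miss. */
    public List<String> read(String conversationId) {
        List<String> cached = cache.get(conversationId);
        if (cached != null && !cached.isEmpty()) {
            return List.copyOf(cached);
        }
        List<String> stored = database.getOrDefault(conversationId, List.of());
        if (!stored.isEmpty()) {
            cache.put(conversationId, new ArrayList<>(stored)); // backfill
        }
        return List.copyOf(stored);
    }
}
```

One consequence of writing to the database asynchronously is that a read arriving after a cache eviction but before the database write completes can miss the newest messages; whether that window is acceptable depends on the application.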

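5.3 Custom Storage Adapter

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Custom storage adapter interface
 */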
public interface StorageAdapter {
    
    void store(String conversationId, List<Message> messages);
    
    List<Message> retrieve(String conversationId, int limit);
    
    void delete(String conversationId);
    
    boolean exists(String conversationId);
}

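/**
 * Example MongoDB storage adapter.
 * ConversationDocument and MessageDocument are application-defined document classes (not shown).
 */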
@Component
public class MongoStorageAdapter implements StorageAdapter {
    
    private final MongoTemplate mongoTemplate;
    
    public MongoStorageAdapter(MongoTemplate mongoTemplate) {
        this.mongoTemplate = mongoTemplate;
    }
    
    @Override
    public void store(String conversationId, List<Message> messages) {
        Query query = Query.query(Criteria.where("conversationId").is(conversationId));
        ConversationDocument doc = mongoTemplate.findOne(query, ConversationDocument.class);
        
        if (doc == null) {
            doc = new ConversationDocument();
            doc.setConversationId(conversationId);
        }
        
        List<MessageDocument> messageDocs = messages.stream()
            .map(this::convertToDocument)
            .toList();
        
        doc.getMessages().addAll(messageDocs);
        mongoTemplate.save(doc);
    }
    
    @Override
    public List<Message> retrieve(String conversationId, int limit) {
        Query query = Query.query(Criteria.where("conversationId").is(conversationId));
        ConversationDocument doc = mongoTemplate.findOne(query, ConversationDocument.class);
        
        if (doc == null) {
            return List.of();
        }
        
        // Messages are embedded in a single document, so the limit applies to the
        // embedded list (the newest `limit` entries) rather than to the query
        List<MessageDocument> stored = doc.getMessages();
        return stored.stream()
            .skip(Math.max(0, stored.size() - limit))
            .map(this::convertToMessage)
            .toList();
    }
    
    @Override
    public void delete(String conversationId) {
        Query query = Query.query(Criteria.where("conversationId").is(conversationId));
        mongoTemplate.remove(query, ConversationDocument.class);
    }
    
    @Override
    public boolean exists(String conversationId) {
        Query query = Query.query(Criteria.where("conversationId").is(conversationId));
        return mongoTemplate.exists(query, ConversationDocument.class);
    }
    
    // Conversion methods omitted...
}

/**
 * ChatMemory implementation backed by the adapter
 */
@Component
public class CustomStorageChatMemory implements ChatMemory {
    
    private final StorageAdapter storageAdapter;
    
    public CustomStorageChatMemory(StorageAdapter storageAdapter) {
        this.storageAdapter = storageAdapter;
    }
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        storageAdapter.store(conversationId, messages);
    }
    
    // Convenience overload; not part of the ChatMemory interface in Spring AI 1.0.x
    public List<Message> get(String conversationId, int lastN) {
        return storageAdapter.retrieve(conversationId, lastN);
    }
    
    @Override
    public List<Message> get(String conversationId) {
        return get(conversationId, 50);
    }
    
    @Override
    public void clear(String conversationId) {
        storageAdapter.delete(conversationId);
    }
}

6. REST API

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.Map;
import java.util.UUID;

@RestController
@RequestMapping("/api/chat")
public class ChatController {
    
    private final ChatService chatService;
    
    public ChatController(ChatService chatService) {
        this.chatService = chatService;
    }
    
    @PostMapping("/send")
    public ResponseEntity<Map<String, Object>> sendMessage(
            @RequestBody ChatRequest request) {
        
        // Generate the ID here so it can be returned to the client; Map.of rejects null values
        String conversationId = request.conversationId() != null
            ? request.conversationId()
            : UUID.randomUUID().toString();
        
        // Tool callbacks are server-side objects and cannot be deserialized from the
        // request body, so none are passed from the HTTP layer in this example
        ChatResponse response = chatService.chat(
            conversationId,
            request.message(),
            List.of()
        );
        
        String text = response.getResult().getOutput().getText();
        
        return ResponseEntity.ok(Map.of(
            "conversationId", conversationId,
            "response", text != null ? text : "",
            "metadata", response.getMetadata()
        ));
    }
    
    @GetMapping("/history/{conversationId}")
    public ResponseEntity<List<ChatMessage>> getHistory(
            @PathVariable String conversationId) {
        return ResponseEntity.ok(chatService.getConversationHistory(conversationId));
    }
    
    @DeleteMapping("/{conversationId}")
    public ResponseEntity<Void> deleteConversation(
            @PathVariable String conversationId) {
        chatService.deleteConversation(conversationId);
        return ResponseEntity.ok().build();
    }
    
    @GetMapping("/conversations")
    public ResponseEntity<List<ChatConversation>> getUserConversations(
            @RequestParam String userId) {
        return ResponseEntity.ok(
            chatService.getConversationsByUserId(userId)
        );
    }
}

// Records expose accessors named after their components: conversationId(), message()
record ChatRequest(
    String conversationId,
    String message
) {}

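7. Best Practices and Caveats

7.1 Performance Optimization

  1. Batch operations: prefer batch inserts over inserting rows one at a time
  2. Index optimization: add indexes on frequently queried columns
  3. Caching strategy: use Redis to cache hot data
  4. Paginated loading: load message history page by page rather than all at once

// Paginated query example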
@Query("SELECT m FROM ChatMessage m WHERE m.conversation.conversationId = :conversationId " +
       "ORDER BY m.timestamp ASC")
Page<ChatMessage> findByConversationId(@Param("conversationId") String conversationId, 
                                       Pageable pageable);
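
For the first point, one framework-free way to prepare batches is to split the pending messages into fixed-size chunks and save each chunk in a single call. The sketch below shows only the chunking step; `BatchSketch` and `chunk` are hypothetical names introduced for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper that splits a list into fixed-size batches. */
public class BatchSketch {

    /** Returns consecutive sublists of at most batchSize elements, in order. */
    public static <T> List<List<T>> chunk(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int i = 0; i < items.size(); i += batchSize) {
            int end = Math.min(i + batchSize, items.size());
            batches.add(List.copyOf(items.subList(i, end)));
        }
        return batches;
    }
}
```

Each batch could then be handed to a repository's `saveAll` method. Note that Hibernate does not batch JDBC inserts for entities that use `GenerationType.IDENTITY`, as the entities in this article do, so real insert batching would also require a different ID generation strategy.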

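7.2 Data Security

  1. Sensitive data filtering: filter out sensitive information before persisting
  2. Data encryption: encrypt stored message content
  3. Access control: enforce access control based on user permissions

import org.springframework.stereotype.Component;

import java.util.regex.Pattern;

@Component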
public class SensitiveDataFilter {
    
    private static final Pattern SENSITIVE_PATTERN = 
        Pattern.compile("(password|token|secret|key)\\s*[:=]\\s*[^\\s]+", Pattern.CASE_INSENSITIVE);
    
    public String filterSensitiveData(String content) {
        return SENSITIVE_PATTERN.matcher(content).replaceAll("[REDACTED]");
    }
}
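
To see how this pattern behaves on real input, the sketch below reuses the same regular expression in a plain static method with no Spring dependency. `RedactionSketch` is a hypothetical name used only for this demonstration.

```java
import java.util.regex.Pattern;

/** Hypothetical stand-alone copy of the filter above, for experimenting with inputs. */
public class RedactionSketch {

    // Same expression as SensitiveDataFilter: a keyword, then ':' or '=', then a non-space value
    private static final Pattern SENSITIVE_PATTERN =
        Pattern.compile("(password|token|secret|key)\\s*[:=]\\s*[^\\s]+", Pattern.CASE_INSENSITIVE);

    /** Replaces every "keyword = value" occurrence with a fixed marker. */
    public static String redact(String content) {
        return SENSITIVE_PATTERN.matcher(content).replaceAll("[REDACTED]");
    }
}
```

With this expression, `redact("login with password=abc123 please")` returns `login with [REDACTED] please`. Because the keywords are matched anywhere in a word, `redact("monkey: banana")` returns `mon[REDACTED]`, so the pattern may need tightening before it is applied to free-form chat text.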

7.3 Data Cleanup Strategy

@Component
@Slf4j
@RequiredArgsConstructor // Lombok: generates the constructor for the final fields
public class ChatDataCleanupScheduler {
    
    private final ChatConversationRepository conversationRepository;
    private final ChatMessageRepository messageRepository;
    
    @Scheduled(cron = "0 0 2 * * ?") // runs every day at 02:00
    @Transactional
    public void cleanupOldConversations() {
        LocalDateTime cutoffDate = LocalDateTime.now().minusDays(90);
        
        List<ChatConversation> oldConversations = 
            conversationRepository.findByUpdatedAtBefore(cutoffDate);
        
        log.info("Found {} old conversations to clean up", oldConversations.size());
        
        conversationRepository.deleteAll(oldConversations);
        
        log.info("Cleaned up {} old conversations", oldConversations.size());
    }
}

7.4 Monitoring and Logging

@Aspect
@Component
@Slf4j
public class ChatPerformanceMonitor {
    
    @Around("execution(* com.example.chat.service.ChatService.chat(..))")
    public Object monitorChatPerformance(ProceedingJoinPoint joinPoint) throws Throwable {
        long startTime = System.currentTimeMillis();
        
        try {
            Object result = joinPoint.proceed();
            long duration = System.currentTimeMillis() - startTime;
            
            log.info("Chat request completed in {} ms", duration);
            
            // Record to the monitoring system
            recordMetric("chat.duration", duration);
            
            return result;
        } catch (Exception e) {
            long duration = System.currentTimeMillis() - startTime;
            log.error("Chat request failed after {} ms", duration, e);
            recordMetric("chat.errors", 1);
            throw e;
        }
    }
    
    private void recordMetric(String name, long value) {
        // Integrate with Prometheus, Micrometer, etc.
    }
}

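8. Summary

This article walked through a complete approach to conversation persistence with Spring AI 1.0.4, covering:

  1. Basic configuration: database entities, repositories, and a ChatMemory implementation
  2. Tool call persistence: an extended entity, a listener, and a complete chat service
  3. Custom persistence: Redis, MongoDB, and other storage options
  4. Hybrid strategy: a Redis cache combined with database persistence
  5. Best practices: performance optimization, data security, cleanup strategy, monitoring and logging

With careful design and implementation, you can build a fast, reliable conversation persistence system that gives users a continuous conversation experience while supporting data analysis and history review.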


Author: AI 技术分享者
Published: 2024
Version: Spring AI 1.0.4
