一、背景与需求

        在开发智能对话系统时,我们经常遇到一个挑战:如何在项目重启后保持对话状态。想象一下,用户正在与你的 AI 助手进行多轮对话,突然服务器重启,所有对话历史都丢失了,用户不得不重新开始,这会严重影响用户体验。

        Spring AI 提供了 AiMessageChatMemory 接口,它为我们提供了管理对话历史的能力。但默认实现通常是基于内存的,项目重启后数据会丢失。本文将结合腾讯云开发者社区的相关实践,详细介绍如何实现 AiMessageChatMemory 的持久化存储,确保对话状态在项目重启后能够恢复。

Spring AI 聊天记忆管理:MessageWindowChatMemory 与 MessageChatMemoryAdvisor 详解-腾讯云开发者社区-腾讯云https://cloud.tencent.com/developer/article/2588729

二、AiMessageChatMemory 简介

AiMessageChatMemory 是 Spring AI 中用于管理对话历史的核心接口,它提供了以下主要功能:

  • 存储用户和 AI 的对话消息
  • 管理对话上下文长度
  • 提供对话历史的访问和修改方法

默认的实现类SimpleAiMessageChatMemory是基于内存的,适合单会话场景,但不支持持久化。

三、实现持久化存储方案

1. 数据库设计

首先,我们需要设计一个数据库表来存储对话历史。

CREATE TABLE chat_memory (
    id BIGINT PRIMARY KEY AUTO_INCREMENT,
    session_id VARCHAR(255) NOT NULL,
    message_type VARCHAR(50) NOT NULL,  -- USER, AI, SYSTEM
    content TEXT NOT NULL,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    embedding JSON NULL,  -- 可选,存储消息的嵌入向量
    metadata JSON NULL    -- 可选,存储额外元数据
);

CREATE INDEX idx_session_id ON chat_memory(session_id);
CREATE INDEX idx_timestamp ON chat_memory(timestamp);

2. 实现自定义 ChatMemory

创建一个实现AiMessageChatMemory接口的类

import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AiMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

@Component
public class PersistentAiMessageChatMemory implements AiMessageChatMemory {

    private final DataSource dataSource;
    private final String sessionId;
    private final int maxMessages;

    public PersistentAiMessageChatMemory(DataSource dataSource, String sessionId, int maxMessages) {
        this.dataSource = dataSource;
        this.sessionId = sessionId;
        this.maxMessages = maxMessages;
    }

    @Override
    public void add(UserMessage message) {
        saveMessage("USER", message.getContent());
    }

    @Override
    public void add(AiMessage message) {
        saveMessage("AI", message.getContent());
    }

    @Override
    public void add(SystemMessage message) {
        saveMessage("SYSTEM", message.getContent());
    }

    @Override
    public List<Message> getMessages() {
        List<Message> messages = new ArrayList<>();
        String sql = "SELECT message_type, content FROM chat_memory WHERE session_id = ? ORDER BY timestamp DESC LIMIT ?";
        
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, sessionId);
            stmt.setInt(2, maxMessages);
            
            ResultSet rs = stmt.executeQuery();
            List<Message> tempMessages = new ArrayList<>();
            
            while (rs.next()) {
                String type = rs.getString("message_type");
                String content = rs.getString("content");
                
                switch (type) {
                    case "USER":
                        tempMessages.add(new UserMessage(content));
                        break;
                    case "AI":
                        tempMessages.add(new AiMessage(content));
                        break;
                    case "SYSTEM":
                        tempMessages.add(new SystemMessage(content));
                        break;
                }
            }
            
            // 反转列表,使消息按时间顺序排列
            for (int i = tempMessages.size() - 1; i >= 0; i--) {
                messages.add(tempMessages.get(i));
            }
            
        } catch (SQLException e) {
            e.printStackTrace();
        }
        
        return messages;
    }

    @Override
    public void clear() {
        String sql = "DELETE FROM chat_memory WHERE session_id = ?";
        
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, sessionId);
            stmt.executeUpdate();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    private void saveMessage(String type, String content) {
        String sql = "INSERT INTO chat_memory (session_id, message_type, content) VALUES (?, ?, ?)";
        
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, sessionId);
            stmt.setString(2, type);
            stmt.setString(3, content);
            stmt.executeUpdate();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        
        // 清理超出限制的旧消息
        cleanupOldMessages();
    }

    private void cleanupOldMessages() {
        String sql = "DELETE FROM chat_memory WHERE session_id = ? AND id NOT IN (" +
                     "SELECT id FROM chat_memory WHERE session_id = ? ORDER BY timestamp DESC LIMIT ?" +
                     ")";
        
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, sessionId);
            stmt.setString(2, sessionId);
            stmt.setInt(3, maxMessages);
            stmt.executeUpdate();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}

3. 创建 Memory 管理服务

Service服务层

import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.stereotype.Service;

import javax.sql.DataSource;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class ChatMemoryService {

    private final DataSource dataSource;
    private final Map<String, AiMessageChatMemory> memoryMap = new ConcurrentHashMap<>();
    private static final int DEFAULT_MAX_MESSAGES = 50;

    public ChatMemoryService(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    public AiMessageChatMemory getMemory(String sessionId) {
        return memoryMap.computeIfAbsent(sessionId, 
            id -> new PersistentAiMessageChatMemory(dataSource, id, DEFAULT_MAX_MESSAGES));
    }

    public void removeMemory(String sessionId) {
        memoryMap.remove(sessionId);
    }
}

4. 在控制器中使用

import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AiMessage;
import org.springframework.ai.chat.ChatClient;
import org.springframework.web.bind.annotation.*;

@RestController
@RequestMapping("/chat")
public class ChatController {

    private final ChatClient chatClient;
    private final ChatMemoryService memoryService;

    public ChatController(ChatClient chatClient, ChatMemoryService memoryService) {
        this.chatClient = chatClient;
        this.memoryService = memoryService;
    }

    @PostMapping("/message")
    public String chat(@RequestParam String sessionId, @RequestParam String message) {
        // 获取或创建对话记忆
        AiMessageChatMemory memory = memoryService.getMemory(sessionId);
        
        // 添加用户消息
        memory.add(new UserMessage(message));
        
        // 构建完整上下文
        var prompt = memory.getMessages();
        
        // 调用 AI 模型
        var response = chatClient.call(prompt);
        
        // 添加 AI 响应
        memory.add(new AiMessage(response.getResult().getOutput().getContent()));
        
        return response.getResult().getOutput().getContent();
    }

    @PostMapping("/clear")
    public String clearMemory(@RequestParam String sessionId) {
        AiMessageChatMemory memory = memoryService.getMemory(sessionId);
        memory.clear();
        memoryService.removeMemory(sessionId);
        return "对话历史已清空";
    }
}

四、项目重启后恢复对话

1. 会话管理

为了在项目重启后能够恢复对话,我们需要:

  1. 唯一会话标识 :为每个用户会话生成唯一的sessionId
  2. 会话状态存储 :将会话信息存储在数据库或Redis中
  3. 会话恢复机制 :项目启动时加载活跃会话

2. 实现会话恢复

import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.HashSet;
import java.util.Set;

@Component
public class ChatMemoryInitializer implements ApplicationRunner {

    private final DataSource dataSource;
    private final ChatMemoryService memoryService;

    public ChatMemoryInitializer(DataSource dataSource, ChatMemoryService memoryService) {
        this.dataSource = dataSource;
        this.memoryService = memoryService;
    }

    @Override
    public void run(ApplicationArguments args) throws Exception {
        // 加载活跃会话
        Set<String> sessionIds = getActiveSessions();
        
        for (String sessionId : sessionIds) {
            // 预加载对话记忆
            memoryService.getMemory(sessionId);
            System.out.println("已恢复会话: " + sessionId);
        }
        
        System.out.println("共恢复 " + sessionIds.size() + " 个活跃会话");
    }

    private Set<String> getActiveSessions() throws Exception {
        Set<String> sessionIds = new HashSet<>();
        String sql = "SELECT DISTINCT session_id FROM chat_memory";
        
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql);
             ResultSet rs = stmt.executeQuery()) {
            
            while (rs.next()) {
                sessionIds.add(rs.getString("session_id"));
            }
        }
        
        return sessionIds;
    }
}

五、总结  

        本方案基于 MySQL + Java 实现的会话状态持久化,非常适合课程项目、中小规模应用或对实时性要求不高的场景。其核心优势在于:

  • 数据持久化可靠 :MySQL作为成熟的关系型数据库,提供完整的事务支持和数据一致性保障
  • 开发成本低 :SQL语法直观,Java生态成熟,学习曲线平缓
  • 功能完整 :满足会话管理、消息存储、历史查询等基础需求
  • 可维护性强 :便于后续功能扩展

        在高并发、大流量、低延迟的实战场景中,更推荐采用Redis作为缓存层 + MySQL作为持久化层的混合架构。

        通过实现AiMessageChatMemory的持久化存储,我们可以确保在项目重启后能够恢复对话状态,为用户提供连续、一致的对话体验。这种方案不仅适用于智能助手类应用,也适用于任何需要保持对话上下文的场景。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐