【Spring AI】AiMessageChatMemory 实战:简单通过MySQL数据库实现项目重启后对话状态恢复
一、背景与需求
在开发智能对话系统时,我们经常遇到一个挑战:如何在项目重启后保持对话状态。想象一下,用户正在与你的 AI 助手进行多轮对话,突然服务器重启,所有对话历史都丢失了,用户不得不重新开始,这会严重影响用户体验。
Spring AI 提供了 AiMessageChatMemory 接口,它为我们提供了管理对话历史的能力。但默认实现通常是基于内存的,项目重启后数据会丢失。本文将结合腾讯云开发者社区的相关实践,详细介绍如何实现 AiMessageChatMemory 的持久化存储,确保对话状态在项目重启后能够恢复。
二、AiMessageChatMemory 简介
AiMessageChatMemory 是 Spring AI 中用于管理对话历史的核心接口,它提供了以下主要功能:
- 存储用户和 AI 的对话消息
- 管理对话上下文长度
- 提供对话历史的访问和修改方法
默认的实现类SimpleAiMessageChatMemory是基于内存的,适合单会话场景,但不支持持久化。
三、实现持久化存储方案
1. 数据库设计
首先,我们需要设计一个数据库表来存储对话历史。
CREATE TABLE chat_memory (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
session_id VARCHAR(255) NOT NULL,
message_type VARCHAR(50) NOT NULL, -- USER, AI, SYSTEM
content TEXT NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
embedding JSON NULL, -- 可选,存储消息的嵌入向量
metadata JSON NULL -- 可选,存储额外元数据
);
CREATE INDEX idx_session_id ON chat_memory(session_id);
CREATE INDEX idx_timestamp ON chat_memory(timestamp);
2. 实现自定义 ChatMemory
创建一个实现AiMessageChatMemory接口的类
import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AiMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
@Component
public class PersistentAiMessageChatMemory implements AiMessageChatMemory {
private final DataSource dataSource;
private final String sessionId;
private final int maxMessages;
public PersistentAiMessageChatMemory(DataSource dataSource, String sessionId, int maxMessages) {
this.dataSource = dataSource;
this.sessionId = sessionId;
this.maxMessages = maxMessages;
}
@Override
public void add(UserMessage message) {
saveMessage("USER", message.getContent());
}
@Override
public void add(AiMessage message) {
saveMessage("AI", message.getContent());
}
@Override
public void add(SystemMessage message) {
saveMessage("SYSTEM", message.getContent());
}
@Override
public List<Message> getMessages() {
List<Message> messages = new ArrayList<>();
String sql = "SELECT message_type, content FROM chat_memory WHERE session_id = ? ORDER BY timestamp DESC LIMIT ?";
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, sessionId);
stmt.setInt(2, maxMessages);
ResultSet rs = stmt.executeQuery();
List<Message> tempMessages = new ArrayList<>();
while (rs.next()) {
String type = rs.getString("message_type");
String content = rs.getString("content");
switch (type) {
case "USER":
tempMessages.add(new UserMessage(content));
break;
case "AI":
tempMessages.add(new AiMessage(content));
break;
case "SYSTEM":
tempMessages.add(new SystemMessage(content));
break;
}
}
// 反转列表,使消息按时间顺序排列
for (int i = tempMessages.size() - 1; i >= 0; i--) {
messages.add(tempMessages.get(i));
}
} catch (SQLException e) {
e.printStackTrace();
}
return messages;
}
@Override
public void clear() {
String sql = "DELETE FROM chat_memory WHERE session_id = ?";
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, sessionId);
stmt.executeUpdate();
} catch (SQLException e) {
e.printStackTrace();
}
}
private void saveMessage(String type, String content) {
String sql = "INSERT INTO chat_memory (session_id, message_type, content) VALUES (?, ?, ?)";
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, sessionId);
stmt.setString(2, type);
stmt.setString(3, content);
stmt.executeUpdate();
} catch (SQLException e) {
e.printStackTrace();
}
// 清理超出限制的旧消息
cleanupOldMessages();
}
private void cleanupOldMessages() {
String sql = "DELETE FROM chat_memory WHERE session_id = ? AND id NOT IN (" +
"SELECT id FROM chat_memory WHERE session_id = ? ORDER BY timestamp DESC LIMIT ?" +
")";
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, sessionId);
stmt.setString(2, sessionId);
stmt.setInt(3, maxMessages);
stmt.executeUpdate();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
3. 创建 Memory 管理服务
Service服务层
import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.stereotype.Service;
import javax.sql.DataSource;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Service
public class ChatMemoryService {
private final DataSource dataSource;
private final Map<String, AiMessageChatMemory> memoryMap = new ConcurrentHashMap<>();
private static final int DEFAULT_MAX_MESSAGES = 50;
public ChatMemoryService(DataSource dataSource) {
this.dataSource = dataSource;
}
public AiMessageChatMemory getMemory(String sessionId) {
return memoryMap.computeIfAbsent(sessionId,
id -> new PersistentAiMessageChatMemory(dataSource, id, DEFAULT_MAX_MESSAGES));
}
public void removeMemory(String sessionId) {
memoryMap.remove(sessionId);
}
}
4. 在控制器中使用
import org.springframework.ai.chat.memory.AiMessageChatMemory;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AiMessage;
import org.springframework.ai.chat.ChatClient;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/chat")
public class ChatController {
private final ChatClient chatClient;
private final ChatMemoryService memoryService;
public ChatController(ChatClient chatClient, ChatMemoryService memoryService) {
this.chatClient = chatClient;
this.memoryService = memoryService;
}
@PostMapping("/message")
public String chat(@RequestParam String sessionId, @RequestParam String message) {
// 获取或创建对话记忆
AiMessageChatMemory memory = memoryService.getMemory(sessionId);
// 添加用户消息
memory.add(new UserMessage(message));
// 构建完整上下文
var prompt = memory.getMessages();
// 调用 AI 模型
var response = chatClient.call(prompt);
// 添加 AI 响应
memory.add(new AiMessage(response.getResult().getOutput().getContent()));
return response.getResult().getOutput().getContent();
}
@PostMapping("/clear")
public String clearMemory(@RequestParam String sessionId) {
AiMessageChatMemory memory = memoryService.getMemory(sessionId);
memory.clear();
memoryService.removeMemory(sessionId);
return "对话历史已清空";
}
}
四、项目重启后恢复对话
1. 会话管理
为了在项目重启后能够恢复对话,我们需要:
- 唯一会话标识 :为每个用户会话生成唯一的sessionId
- 会话状态存储 :将会话信息存储在数据库或Redis中
- 会话恢复机制 :项目启动时加载活跃会话
2. 实现会话恢复
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.HashSet;
import java.util.Set;
@Component
public class ChatMemoryInitializer implements ApplicationRunner {
private final DataSource dataSource;
private final ChatMemoryService memoryService;
public ChatMemoryInitializer(DataSource dataSource, ChatMemoryService memoryService) {
this.dataSource = dataSource;
this.memoryService = memoryService;
}
@Override
public void run(ApplicationArguments args) throws Exception {
// 加载活跃会话
Set<String> sessionIds = getActiveSessions();
for (String sessionId : sessionIds) {
// 预加载对话记忆
memoryService.getMemory(sessionId);
System.out.println("已恢复会话: " + sessionId);
}
System.out.println("共恢复 " + sessionIds.size() + " 个活跃会话");
}
private Set<String> getActiveSessions() throws Exception {
Set<String> sessionIds = new HashSet<>();
String sql = "SELECT DISTINCT session_id FROM chat_memory";
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql);
ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
sessionIds.add(rs.getString("session_id"));
}
}
return sessionIds;
}
}
五、总结
本方案基于 MySQL + Java 实现的会话状态持久化,非常适合课程项目、中小规模应用或对实时性要求不高的场景。其核心优势在于:
- 数据持久化可靠 :MySQL作为成熟的关系型数据库,提供完整的事务支持和数据一致性保障
- 开发成本低 :SQL语法直观,Java生态成熟,学习曲线平缓
- 功能完整 :满足会话管理、消息存储、历史查询等基础需求
- 可维护性强 :便于后续功能扩展
在高并发、大流量、低延迟的实战场景中,更推荐采用Redis作为缓存层 + MySQL作为持久化层的混合架构。
通过实现AiMessageChatMemory的持久化存储,我们可以确保在项目重启后能够恢复对话状态,为用户提供连续、一致的对话体验。这种方案不仅适用于智能助手类应用,也适用于任何需要保持对话上下文的场景。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)