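Figure 1: WebSocket long-lived connection architecture

Overview

Large language models (LLMs) have become a core technology for dialogue systems, intelligent customer service, and content generation. The traditional HTTP request-response model, however, has clear limits in real-time interaction: every exchange resends verbose HTTP headers (and, without keep-alive, opens a new connection), and the server cannot push data on its own initiative. These limits are most visible in long, multi-turn conversations, where they hurt both user experience and system efficiency.

WebSocket addresses these problems. It is a full-duplex protocol built on TCP: after a single handshake the client and server hold a persistent connection, and either side can send data at any time without reconnecting. That makes WebSocket a good fit for real-time dialogue systems, and combined with an LLM it supports smooth streaming responses, multi-turn context, and immediate feedback.

This article explains how to integrate WebSocket with LLM services on Spring Boot to build a high-performance long-connection interaction service. It covers protocol fundamentals, technology selection, implementation details, and security strategy, and works through a complete practical example.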

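1. WebSocket vs. SSE: When to Choose WebSocket

Figure 4: Suitable scenarios for WebSocket and SSE

1.1 Technical Characteristics Compared

Before discussing WebSocket, it helps to understand its main alternative, Server-Sent Events (SSE). SSE is an HTML5 technology that lets a server push events to the browser over HTTP as a one-way stream from server to client. It looks simpler than WebSocket, but in some scenarios it is the better choice.

**WebSocket's core advantage** is full-duplex communication. Once the connection is established, both client and server can send messages on their own initiative without waiting for the other side. This suits scenarios with frequent two-way data exchange, such as real-time chat, multiplayer online games, collaborative editing tools, and financial trading systems. Latency is critical in all of them, and WebSocket's low overhead and immediacy meet that need.

**SSE's distinct value** lies in its simplicity and its compatibility with HTTP. SSE travels over standard HTTP/HTTPS with no protocol upgrade, so it passes through most firewalls and proxies without special handling. Where only the server needs to push data (notifications, stock price updates, progress reports), SSE is simpler to implement. Browsers also support SSE natively through `EventSource`, including automatic reconnection, so no connection-management code is needed.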

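1.2 Communication Model and Protocol Overhead

Figure 2: WebSocket handshake and communication flow

At the protocol level, once the connection is established WebSocket carries data in its own frame format, and a frame header can be as small as 2 bytes (6 bytes for client-to-server frames, which add a 4-byte masking key). SSE stays on HTTP: the HTTP headers are sent once when the stream opens, and each event is then framed as text lines such as `data:` and `event:` terminated by a blank line. Per-event overhead is therefore small in both cases, but SSE is text-only and one-way, so every client-to-server message still costs a separate HTTP request with full headers. For workloads that exchange many small messages in both directions, that difference adds up.

The WebSocket frame layout is compact, with a header and a payload. The header needs as little as 2 bytes and carries control information, the mask bit, and the payload length; the payload carries the application data directly, with no extra formatting. This encoding makes WebSocket considerably more efficient than HTTP polling, and somewhat leaner than SSE for small, frequent messages.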
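
To make the overhead figures concrete, the following minimal sketch computes the header size of a single WebSocket frame from the payload length, following the length-encoding rules of RFC 6455. The class and method names are illustrative and not part of any library.

```java
/** Computes the size of a WebSocket frame header (RFC 6455, section 5.2). */
final class FrameOverhead {

    private FrameOverhead() {}

    /**
     * @param payloadLength number of payload bytes (must be non-negative)
     * @param masked        true for client-to-server frames, which carry a 4-byte masking key
     * @return header size in bytes
     */
    static int headerSize(long payloadLength, boolean masked) {
        if (payloadLength < 0) {
            throw new IllegalArgumentException("payloadLength must be >= 0");
        }
        int size = 2;                 // flags/opcode byte + mask/length byte
        if (payloadLength > 65_535) {
            size += 8;                // 64-bit extended payload length
        } else if (payloadLength > 125) {
            size += 2;                // 16-bit extended payload length
        }
        if (masked) {
            size += 4;                // masking key
        }
        return size;
    }
}
```

A short token pushed from server to client therefore costs 2 bytes of framing, while the same token sent by a client costs 6.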

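1.3 Choosing for LLM Scenarios

For LLM applications several factors matter: whether the client sends large amounts of request data, whether the server streams its response, whether conversation state must be kept, and how many concurrent connections are expected.

For **streaming LLM output**, both WebSocket and SSE work, each with trade-offs. With WebSocket, the server pushes chunks at its own pace and the client renders each one as it arrives, giving a true "typewriter" effect. With SSE, the server pushes events using the `text/event-stream` content type, and the browser parses them and reconnects automatically.

WebSocket shows its advantage in **complex multi-turn conversations**. When the conversation history is kept on the server and the client needs to send corrections or additional information mid-response, two-way communication is required. SSE combined with the fetch API can achieve a similar effect, but the WebSocket API is more direct and its state management is clearer.

Overall, for LLM applications that need complex interaction, multi-turn dialogue, or two-way data exchange, **WebSocket is the better fit**. Its full-duplex design, low protocol overhead, and mature ecosystem make it a strong default for high-performance LLM interaction services.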

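2. WebSocket Protocol in Depth

2.1 The Handshake

A WebSocket connection starts with a special HTTP request carrying an `Upgrade` header, which states that the client wants to switch the connection to the WebSocket protocol. A server that supports WebSocket replies with HTTP status 101 (Switching Protocols), after which both sides exchange WebSocket frames.

A typical client handshake request contains the following key headers: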

GET /ws HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

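`Sec-WebSocket-Key` is a Base64-encoded random value generated by the client. The server appends the fixed GUID (258EAFA5-E914-47DA-95CA-C5AB0DC85B11) to this value, computes the SHA-1 hash of the result, Base64-encodes the digest, and returns it in the `Sec-WebSocket-Accept` header. This exchange proves that the server actually understood the WebSocket handshake, and it keeps ordinary HTTP requests or cached responses from being mistaken for one. It is not an authentication mechanism.

The server's response, including the derived accept value, looks like this: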

HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

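A successful handshake means the application-layer protocol on the connection has switched from HTTP to WebSocket's own framing; the underlying TCP connection (and TLS session, if any) is reused. From then on all communication uses WebSocket frames rather than HTTP messages. Because the connection starts as ordinary HTTP on port 80 or 443, it passes through most firewalls and proxies.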
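
The derivation of `Sec-WebSocket-Accept` is easy to reproduce with the JDK alone. The sketch below is illustrative (the class name `HandshakeAccept` is made up for this example) and uses the sample key from RFC 6455 to show the computation.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

/** Derives the Sec-WebSocket-Accept value from a client's Sec-WebSocket-Key. */
final class HandshakeAccept {

    // Fixed GUID defined by RFC 6455
    private static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private HandshakeAccept() {}

    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            // Concatenate key and GUID, hash with SHA-1, then Base64-encode the digest
            byte[] digest = sha1.digest(
                (secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform implementation is required to provide SHA-1
            throw new IllegalStateException("SHA-1 not available", e);
        }
    }
}
```

For the RFC's sample key `dGhlIHNhbXBsZSBub25jZQ==`, `acceptFor` returns `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`. In practice the servlet container or Spring performs this step, so application code rarely needs it.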

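2.2 Frame Structure and Data Encapsulation

The basic unit of WebSocket communication is the frame. Each frame consists of a header, which carries control information and metadata, and a payload, which carries the application data.

The high 4 bits of the first header byte hold the flags (FIN plus three reserved bits, RSV1 to RSV3), and the low 4 bits hold the opcode. The opcode defines the frame type: 0x0 is a continuation frame (used for fragmented messages), 0x1 a text frame, 0x2 a binary frame, 0x8 a close frame, 0x9 a Ping frame, and 0xA a Pong frame. The most significant bit of the second byte is the MASK bit; when it is 1, the payload is masked. The remaining 7 bits encode the payload length.

Payloads longer than 125 bytes use extra bytes for the length. A value of 126 means the next 2 bytes are an unsigned 16-bit payload length; 127 means the next 8 bytes are an unsigned 64-bit payload length. This variable-length encoding is compact and flexible, covering anything from a few bytes to multi-gigabyte payloads.

**Masking** is part of the protocol's security design. Every frame sent from client to server must set the MASK bit and XOR the payload with a 4-byte masking key carried in the header. The server applies the same XOR with the same key to recover the data. Masking exists to prevent cache-poisoning attacks against intermediaries such as proxies; it does not provide confidentiality, which requires TLS (`wss://`).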
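
The header layout and masking rule can be illustrated with a small decoder for a single, complete frame. This is a sketch for study only: `FrameDecoder` is a hypothetical name, and it ignores fragmentation, control-frame rules, and streaming input, all of which a real server handles for you.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

/** Decodes one complete WebSocket frame held in a byte array. */
final class FrameDecoder {

    /** Minimal holder for the decoded fields. */
    static final class Frame {
        final boolean fin;
        final int opcode;
        final byte[] payload;

        Frame(boolean fin, int opcode, byte[] payload) {
            this.fin = fin;
            this.opcode = opcode;
            this.payload = payload;
        }

        String text() {
            return new String(payload, StandardCharsets.UTF_8);
        }
    }

    private FrameDecoder() {}

    static Frame decode(byte[] raw) {
        if (raw.length < 2) {
            throw new IllegalArgumentException("A frame header needs at least 2 bytes");
        }
        ByteBuffer buf = ByteBuffer.wrap(raw);   // ByteBuffer is big-endian by default
        int b0 = buf.get() & 0xFF;
        int b1 = buf.get() & 0xFF;

        boolean fin = (b0 & 0x80) != 0;          // highest bit of byte 0
        int opcode = b0 & 0x0F;                  // low 4 bits of byte 0
        boolean masked = (b1 & 0x80) != 0;       // highest bit of byte 1

        long length = b1 & 0x7F;                 // 7-bit length or an escape value
        if (length == 126) {
            length = buf.getShort() & 0xFFFF;    // 16-bit extended length
        } else if (length == 127) {
            length = buf.getLong();              // 64-bit extended length
        }
        int keyBytes = masked ? 4 : 0;
        if (length < 0 || length > buf.remaining() - keyBytes) {
            throw new IllegalArgumentException("Truncated or oversized frame");
        }

        byte[] key = new byte[4];
        if (masked) {
            buf.get(key);
        }
        byte[] payload = new byte[(int) length];
        buf.get(payload);
        if (masked) {
            for (int i = 0; i < payload.length; i++) {
                payload[i] ^= key[i % 4];        // XOR with the rotating masking key
            }
        }
        return new Frame(fin, opcode, payload);
    }
}
```

Decoding the masked sample frame from RFC 6455 (`81 85 37 fa 21 3d 7f 9f 4d 51 58`) yields a final text frame whose payload is `Hello`.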

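2.3 Heartbeat and Keep-Alive

A long-lived WebSocket connection can be dropped when intermediate network devices time out idle state. To keep the connection alive, WebSocket implementations commonly use a heartbeat. The protocol defines two control frames for this, Ping and Pong: either side may send a Ping, and the peer must reply with a Pong.

A typical implementation has the server send Ping frames to the client periodically and check for the reply. If no Pong arrives within a preset time, the connection is considered unreliable and should be closed, which triggers the client's reconnect logic. The client can also send Pings to confirm the connection, although the browser WebSocket API does not expose Ping frames, so browser clients usually send an application-level heartbeat message instead.

The heartbeat interval is a trade-off. Too short, and it adds network traffic and server load; too long, and dead connections go unnoticed. An interval of 30 to 60 seconds is usually reasonable. In LLM applications a response can take a long time to generate, so the interval should also account for overall user experience.

With Spring's low-level `WebSocketHandler` API there is no built-in heartbeat setting, so the handler schedules it itself: start a periodic task that sends a `PingMessage` in `afterConnectionEstablished`, and override `handleTransportError` to catch transport errors and handle abnormal closes. With STOMP, the simple broker's heartbeat interval can be configured instead.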
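
The liveness check behind a heartbeat is small enough to sketch without any framework. `HeartbeatMonitor` below is a hypothetical helper, not a Spring class: in a Spring handler one would typically call `onPong` from `handlePongMessage`, and a scheduled task would send a `PingMessage` and close the session when `isStale` returns true. Timestamps are passed in explicitly so the logic is easy to test.

```java
/** Tracks the last Pong seen on one connection and reports when it is overdue. */
final class HeartbeatMonitor {

    private final long timeoutMillis;
    private volatile long lastPongMillis;

    /**
     * @param timeoutMillis how long to wait for a Pong before treating the peer as gone
     * @param nowMillis     current time when the connection is established
     */
    HeartbeatMonitor(long timeoutMillis, long nowMillis) {
        if (timeoutMillis <= 0) {
            throw new IllegalArgumentException("timeoutMillis must be positive");
        }
        this.timeoutMillis = timeoutMillis;
        this.lastPongMillis = nowMillis;
    }

    /** Record that a Pong (or any heartbeat reply) arrived. */
    void onPong(long nowMillis) {
        lastPongMillis = nowMillis;
    }

    /** True when no Pong has been seen within the timeout. */
    boolean isStale(long nowMillis) {
        return nowMillis - lastPongMillis > timeoutMillis;
    }
}
```

Choosing a timeout somewhat larger than the ping interval (for example, a 30-second interval with a 60-second timeout) tolerates one lost or delayed reply before the connection is dropped.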

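2.4 Reconnection Strategy

Networks are unpredictable, and a WebSocket connection can drop at any time. A well-behaved client reconnects automatically, following a defined strategy after the connection is lost.

**Exponential backoff** is the usual approach. The first retry happens after 1 second; if it fails, the client waits 2 seconds, then 4, then 8, and so on, up to a maximum wait (typically no more than 30 or 60 seconds). This responds quickly to transient network faults while avoiding a flood of requests to the server during a prolonged outage.

The reconnect logic should also account for the following:

1. **Delay before reconnecting**: reconnecting immediately is often futile, because the fault may need a moment to clear

2. **Maximum retries**: set a sensible retry limit to avoid looping forever

3. **User awareness**: show a suitable notice while reconnecting so the user does not think the application has frozen

4. **State recovery**: after reconnecting, restore the session state rather than starting over

5. **Manual reconnect**: provide a reconnect button so impatient users can trigger it themselves

For LLM applications, reconnection must also restore conversation state. If the connection drops after the user has typed a long message, the input should survive the reconnect instead of being typed again. This requires local state persistence on the client and session state management on the server.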
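
The backoff schedule described in this section reduces to a few lines of arithmetic. The sketch below is written in Java for consistency with the rest of the article; a browser client would port the same calculation to JavaScript. The class name and parameters are illustrative.

```java
/** Computes capped exponential backoff delays for reconnect attempts. */
final class ReconnectBackoff {

    private ReconnectBackoff() {}

    /**
     * @param attempt    zero-based retry number (0 = first retry)
     * @param baseMillis delay before the first retry, for example 1000
     * @param maxMillis  upper bound on the delay, for example 30000
     * @return delay in milliseconds to wait before this attempt
     */
    static long delayMillis(int attempt, long baseMillis, long maxMillis) {
        if (attempt < 0 || baseMillis <= 0 || maxMillis < baseMillis) {
            throw new IllegalArgumentException("invalid backoff parameters");
        }
        long delay = baseMillis;
        // Double once per previous attempt, stopping early once the cap is reached
        for (int i = 0; i < attempt && delay < maxMillis; i++) {
            delay *= 2;
        }
        return Math.min(delay, maxMillis);
    }
}
```

With a 1-second base and a 30-second cap, the delays are 1 s, 2 s, 4 s, 8 s, 16 s, and then 30 s for every later attempt. Adding a small random jitter to each delay is a common refinement that keeps many clients from reconnecting at the same instant after a server restart.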

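3. Technology Selection for Spring Boot WebSocket

Figure 3: How Spring Boot WebSocket integrates with an LLM

3.1 The Standard WebSocket API

Java EE 7 introduced the standard WebSocket API (JSR-356), and Spring Boot supports it natively. Its main benefit is portability: the code does not depend on Spring-specific components and runs in any container that supports JSR-356.

To use it, create one or more classes annotated with `@ServerEndpoint`. The annotation's `value` attribute sets the endpoint path, and the `configurator` attribute names a configuration class that customizes the handshake. With an embedded container, Spring Boot also needs a `ServerEndpointExporter` bean to register these endpoints.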

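import jakarta.websocket.CloseReason;
import jakarta.websocket.OnClose;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.OnOpen;
import jakarta.websocket.Session;
import jakarta.websocket.server.ServerEndpoint;

// On Spring Boot 2.x or Java EE containers, use the javax.websocket packages instead.
// HttpSessionConfigurator is an application-defined ServerEndpointConfig.Configurator.
@ServerEndpoint(value = "/ws/chat", configurator = HttpSessionConfigurator.class)
public class ChatEndpoint {

    @OnOpen
    public void onOpen(Session session) {
        // Handle connection establishment
    }

    @OnMessage
    public void onMessage(String message, Session session) {
        // Handle an incoming text message
    }

    @OnClose
    public void onClose(Session session, CloseReason reason) {
        // Handle connection close
    }

    @OnError
    public void onError(Session session, Throwable error) {
        // Handle errors
    }
}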

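The limitation of the standard API is that it is low-level: it offers only connection- and message-level callbacks. Enterprise applications often need higher-level features such as message routing, topic subscription, user grouping, and heartbeat management, all of which must be built by hand on top of the standard API.

3.2 The STOMP Protocol

STOMP (Simple Text Oriented Messaging Protocol) is a simple, text-based messaging protocol designed to make messaging applications easier to build. It provides a higher-level abstraction on top of WebSocket so that developers can work in a publish-subscribe style.

With STOMP, client and server communicate by exchanging messages rather than raw WebSocket frames. A message has headers and a body; the headers carry metadata such as the destination, the content type, and broker-specific options like persistence.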
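
To show what such a message looks like on the wire, the sketch below builds a STOMP `SEND` frame as plain text: a command line, `name:value` header lines, a blank line, the body, and a terminating NUL character. `StompFrames` is a hypothetical helper written for this illustration; client libraries such as stomp.js build these frames for you, and this sketch skips the header-value escaping that STOMP 1.2 requires for special characters.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Builds the text form of a STOMP SEND frame. */
final class StompFrames {

    private StompFrames() {}

    static String send(String destination, String contentType, String body) {
        // LinkedHashMap keeps the headers in insertion order
        Map<String, String> headers = new LinkedHashMap<>();
        headers.put("destination", destination);
        headers.put("content-type", contentType);

        StringBuilder frame = new StringBuilder("SEND\n");
        headers.forEach((name, value) ->
            frame.append(name).append(':').append(value).append('\n'));
        frame.append('\n')      // blank line separates headers from body
             .append(body)
             .append('\0');     // NUL terminates the frame
        return frame.toString();
    }
}
```

A chat message sent to the `/app/chat` destination would be produced with `StompFrames.send("/app/chat", "application/json", "{\"text\":\"hi\"}")`; the destination here is an example value chosen to match the `/app` prefix used in the configuration that follows.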

Configuring STOMP over WebSocket in Spring Boot takes a configuration class like the following:

import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // In-memory broker that handles subscriptions and broadcasts
        registry.enableSimpleBroker("/topic", "/queue");
        // Prefix for destinations that clients send application messages to
        registry.setApplicationDestinationPrefixes("/app");
    }

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // STOMP endpoint through which clients open the WebSocket connection
        registry.addEndpoint("/ws")
                // Placeholder origin: list the real front-end origins rather than "*"
                .setAllowedOrigins("https://app.example.com")
                .withSockJS();  // SockJS fallback support
    }
}

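The strength of STOMP is its clear message-routing model and well-established conventions. Developers do not deal with connection management or frame parsing and can focus on sending and receiving messages. The protocol also defines transactions and acknowledgements, and a full-featured external broker can add message persistence; Spring's in-memory simple broker supports only a subset of these.

STOMP has limits too. Its design leans toward traditional message-queue scenarios and can feel heavy when fine-grained control over connections and frames is needed. Its heartbeat is a fixed interval negotiated once at connect time, which leaves little room for customization.

3.3 Advanced Features of Spring WebSocket

Spring Framework offers a rich set of WebSocket extensions that make applications more practical without changing the underlying protocol.

**Message interceptors** run custom logic before and after a message is handled. By implementing the `ChannelInterceptor` interface, we can validate messages, log them, check permissions, and so on.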

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;

public class AuthChannelInterceptor implements ChannelInterceptor {

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(
            message, StompHeaderAccessor.class);
        // getAccessor can return null for messages without a STOMP accessor
        if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
            String token = accessor.getFirstNativeHeader("Authorization");
            // validateToken(...) is an application-defined helper (not shown)
            if (!validateToken(token)) {
                throw new IllegalArgumentException("Invalid token");
            }
        }
        return message;
    }
}

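**WebSocketSession management** is another important feature. Spring creates a session object for each WebSocket connection, through which we can read connection metadata, send messages, and manage connection state. In LLM applications we usually maintain a mapping from sessions to user IDs and conversation IDs so that a specific user's connection can be found when needed.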

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

// Spring's base class for text-only handlers is TextWebSocketHandler
public class LLMWebSocketHandler extends TextWebSocketHandler {

    // userId -> live session, so other components can find a user's connection
    private final Map<String, WebSocketSession> sessionMap = new ConcurrentHashMap<>();

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        String payload = message.getPayload();
        // Look up the user (and from it the conversation context) for this session
        String userId = (String) session.getAttributes().get("userId");
        // processLLMRequest(...) is an application-defined method (not shown)
        processLLMRequest(userId, payload, session);
    }
}

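3.4 Recommendations

For most LLM application scenarios, I recommend **Spring WebSocket's native API with custom session management**. This gives the most flexibility and control: the heartbeat, message format, connection pooling, and other behavior can all be tailored to the application.

Specific recommendations:

1. **Simple one-way push**: use SSE; it is simple to configure and well supported by browsers

2. **Publish-subscribe messaging**: use STOMP and benefit from its established messaging patterns

3. **Real-time interaction needing fine control**: use the Spring WebSocket native API

4. **Mixed scenarios**: use WebSocket for the main interaction, with SSE as a supplement for specific features

In real projects, many developers combine technologies, for example WebSocket for the main multi-turn conversation and SSE for system notifications and announcements. A mixed setup plays to each technology's strengths but also adds complexity, so weigh it carefully.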

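4. Streaming Output with Spring WebSocket

4.1 How Streaming Output Works

A signature feature of LLMs is streaming output: the model emits content as it generates it instead of returning only after the whole response is complete. This "typewriter" effect improves the user experience and sharply reduces the time to first token, so users see the start of the answer sooner.

Over WebSocket, the key is that the server keeps sending frames to the client as data becomes available rather than sending everything at once. Java's `Flow` API and Spring's reactive programming support make this relatively simple.

The basic approach: on receiving a request, the server starts an asynchronous task to process it, and that task sends each piece of output through the `WebSocketSession` as soon as it is produced. The client renders data as it arrives without waiting for the full response.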

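@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
    String request = message.getPayload();

    // Run the model call off the WebSocket I/O thread.
    // runAsync uses the common ForkJoinPool; pass a dedicated executor in production.
    CompletableFuture.runAsync(() -> {
        try {
            // Call the LLM API and receive a streaming response.
            // LLMClient and StreamCallback are application-defined types.
            LLMClient client = llmClientFactory.getClient();
            client.streamGenerate(request, new StreamCallback() {
                @Override
                public void onToken(String token) {
                    try {
                        // Forward each token as soon as it arrives. sendMessage is not safe
                        // for concurrent callers; wrap the session in
                        // ConcurrentWebSocketSessionDecorator if several threads may send.
                        session.sendMessage(new TextMessage(token));
                    } catch (IOException e) {
                        logger.error("Failed to send message", e);
                    }
                }

                @Override
                public void onComplete() {
                    try {
                        // Send the completion marker
                        session.sendMessage(new TextMessage("[DONE]"));
                    } catch (IOException e) {
                        logger.error("Failed to send completion", e);
                    }
                }
            });
        } catch (Exception e) {
            logger.error("LLM processing error", e);
        }
    });
}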

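4.2 Streaming Responses with SseEmitter

Besides writing to a `WebSocketSession` directly, Spring MVC provides `SseEmitter` as a higher-level abstraction for streaming output. `SseEmitter` implements Server-Sent Events rather than WebSocket, but the callback-driven sending logic is much the same, so code written for one transport ports to the other with little change.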

@GetMapping("/stream/{conversationId}")
public SseEmitter streamResponse(@PathVariable String conversationId,
                                 @RequestParam String prompt) {
    // Long.MAX_VALUE effectively disables the async timeout; prefer a finite value
    // if abandoned connections must be reclaimed.
    SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);

    // llmService and StreamObserver are application-defined
    llmService.generateStream(prompt, conversationId, new StreamObserver() {
        @Override
        public void onToken(String token) {
            try {
                // Each token becomes one SSE event named "token"
                emitter.send(SseEmitter.event()
                    .name("token")
                    .data(token));
            } catch (IOException e) {
                emitter.completeWithError(e);
            }
        }

        @Override
        public void onComplete() {
            emitter.complete();
        }

        @Override
        public void onError(Throwable t) {
            emitter.completeWithError(t);
        }
    });

    return emitter;
}

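The appeal of `SseEmitter` is its concise API and its integration with Spring MVC. Developers do not need to know anything about WebSocket frames and can focus on the sending logic. `SseEmitter` also has built-in timeout and completion handling, which reduces boilerplate.

4.3 Backpressure and Resource Management

In streaming output, the server may produce data much faster than the client consumes it, or a poor client network may cause data to pile up. A backpressure mechanism is then needed to balance producer and consumer speeds.

Spring WebSocket does not support backpressure directly, but it can be added with Reactor's `Flux`. `Flux` is Reactor's implementation of a Reactive Streams publisher and offers a rich set of operators for handling backpressure.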

// Requires reactor-core: import reactor.core.publisher.Flux;
public Flux<String> streamLLMResponse(String prompt) {
    // The explicit <String> keeps the element type through the chained operator
    return Flux.<String>create(sink -> {
        // StreamObserver and llmClient are application-defined
        StreamObserver<String> observer = new StreamObserver<>() {
            @Override
            public void onNext(String token) {
                sink.next(token);  // Emit the next element
            }

            @Override
            public void onError(Throwable t) {
                sink.error(t);
            }

            @Override
            public void onComplete() {
                sink.complete();
            }
        };

        // Register cleanup before starting so an early cancel is not missed
        sink.onCancel(() -> llmClient.cancel());
        llmClient.stream(prompt, observer);
    }).onBackpressureBuffer(100);  // Buffer up to 100 elements; overflow ends the stream with an error
}
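
The idea of a bounded buffer between producer and consumer can also be shown with the JDK alone. The sketch below is a simplified stand-in for what a reactive operator does, not a replacement for it: `BoundedTokenBuffer` is a hypothetical class in which the model callback calls `offer`, a sender thread calls `poll`, and a rejected `offer` is the signal to pause or cancel generation.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

/** Fixed-capacity buffer that sits between token production and sending. */
final class BoundedTokenBuffer {

    private final BlockingQueue<String> queue;

    BoundedTokenBuffer(int capacity) {
        this.queue = new ArrayBlockingQueue<>(capacity);
    }

    /**
     * Tries to enqueue a token without blocking.
     *
     * @return false when the buffer is full, meaning the consumer has fallen behind
     */
    boolean offer(String token) {
        return queue.offer(token);
    }

    /** Removes and returns the oldest token, or null when the buffer is empty. */
    String poll() {
        return queue.poll();
    }

    int size() {
        return queue.size();
    }
}
```

What to do on a rejected `offer` is a policy decision: dropping tokens corrupts the answer, so for LLM output the usual choices are to block the producer briefly or to cancel the generation and report an error to the client.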

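4.4 Error Handling and Recovery

Error handling during streaming is harder than in the plain request-response model. When an error occurs, part of the data may already have reached the client, and we have to decide how to treat this incomplete response.

A common strategy is to send an error marker so the client knows something went wrong: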

@Override
public void onError(Throwable t) {
    try {
        // Send the error as the final message. Serializing with Jackson (objectMapper is an
        // ObjectMapper field) keeps quotes or newlines in the text from breaking the JSON.
        String json = objectMapper.writeValueAsString(
            Map.of("error", String.valueOf(t.getMessage())));
        session.sendMessage(new TextMessage(json));
        session.close(CloseStatus.SERVER_ERROR);
    } catch (IOException e) {
        logger.error("Failed to send error", e);
    }
}

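On receiving the error marker, the client should handle the failure gracefully, typically by showing an error message and offering a retry.

5. WebSocketSession Management and Message Dispatch

5.1 Session Storage and Mapping

In enterprise applications, managing `WebSocketSession` objects is a core concern. We need to map session IDs to users, conversations, and context so that a specific connection can be located and used when required.

Spring represents a WebSocket connection with the `WebSocketSession` interface. Each session has a unique ID, and custom attributes can be stored and read through `getAttributes()`. A common practice is to put user information into the session attributes during the handshake: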

import java.util.Map;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

public class AuthHandshakeInterceptor implements HandshakeInterceptor {

    @Override
    public boolean beforeHandshake(ServerHttpRequest request,
                                   ServerHttpResponse response,
                                   WebSocketHandler wsHandler,
                                   Map<String, Object> attributes) {
        // Extract credentials from the HTTP upgrade request.
        // extractToken, validateToken, getUserFromToken and User are application-defined.
        String token = extractToken(request);
        if (validateToken(token)) {
            User user = getUserFromToken(token);
            // These entries become available through WebSocketSession.getAttributes()
            attributes.put("userId", user.getId());
            attributes.put("username", user.getName());
            return true;
        }
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        return false;  // Reject the handshake
    }

    @Override
    public void afterHandshake(ServerHttpRequest request,
                               ServerHttpResponse response,
                               WebSocketHandler wsHandler,
                               Exception exception) {
        // Nothing to do after the handshake
    }
}

In LLM applications we also need to track which conversation each user is in:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.WebSocketSession;

@Service
public class ConversationSessionManager {

    // userId -> live WebSocket session on this instance
    private final Map<String, WebSocketSession> userSessions = new ConcurrentHashMap<>();
    // userId -> id of the conversation the user is currently in
    private final Map<String, String> userConversations = new ConcurrentHashMap<>();

    public void registerSession(String userId, WebSocketSession session) {
        userSessions.put(userId, session);
    }

    public void bindConversation(String userId, String conversationId) {
        userConversations.put(userId, conversationId);
    }

    public WebSocketSession getSession(String userId) {
        return userSessions.get(userId);
    }

    public String getConversationId(String userId) {
        return userConversations.get(userId);
    }

    // Call when the connection closes so that closed sessions do not accumulate
    public void removeSession(String userId) {
        userSessions.remove(userId);
        userConversations.remove(userId);
    }
}

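5.2 Message Routing and Dispatch

When a WebSocket server handles several kinds of messages, routing becomes essential. A well-designed router dispatches each message to the right handler based on its content.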

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

public class MessageRouter {

    private static final Logger logger = LoggerFactory.getLogger(MessageRouter.class);
    private final ObjectMapper objectMapper = new ObjectMapper();
    // MessageHandler is an application-defined interface: handle(WebSocketSession, JsonNode).
    // A concurrent map allows handlers to be registered while messages are being routed.
    private final Map<String, MessageHandler> handlers = new ConcurrentHashMap<>();

    public void registerHandler(String type, MessageHandler handler) {
        handlers.put(type, handler);
    }

    public void route(WebSocketSession session, String message) {
        try {
            JsonNode root = objectMapper.readTree(message);
            // path() yields a "missing" node instead of null when the field is absent
            String type = root.path("type").asText();
            MessageHandler handler = handlers.get(type);
            if (handler != null) {
                handler.handle(session, root);
            } else {
                sendError(session, "Unknown message type: " + type);
            }
        } catch (Exception e) {
            sendError(session, "Invalid message format");
        }
    }

    private void sendError(WebSocketSession session, String error) {
        try {
            // Build the JSON with Jackson so the error text is escaped correctly
            String json = objectMapper.writeValueAsString(
                Map.of("type", "error", "message", error));
            session.sendMessage(new TextMessage(json));
        } catch (IOException e) {
            logger.error("Failed to send error", e);
        }
    }
}

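In LLM applications, common message types include starting a conversation, sending a message, heartbeat checks, fetching history, and switching models. Routing lets us spread complex business logic across several handlers and keeps the code clear and maintainable.

5.3 Multi-Instance Deployment and Session Sharing

When the application runs as several instances behind a load balancer, different users' WebSocket connections may land on different instances. This complicates session management: if user A is connected to instance 1 and user B to instance 2, then instance 1 has no local session through which to send a message to user B.

The usual solution is distributed session storage. Redis is the most common choice, offering fast key-value storage and publish-subscribe.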

import java.time.Duration;
import java.util.Optional;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;

@Service
public class RedisSessionRegistry {

    private static final String SESSION_KEY_PREFIX = "ws:session:";

    private final StringRedisTemplate redisTemplate;

    public RedisSessionRegistry(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    public void registerSession(String userId, String instanceId, String sessionId) {
        String key = SESSION_KEY_PREFIX + userId;
        // Value format is "instanceId:sessionId"; the 24-hour TTL clears stale entries
        redisTemplate.opsForValue().set(key,
            instanceId + ":" + sessionId,
            Duration.ofHours(24));
    }

    public Optional<SessionInfo> getSession(String userId) {
        String key = SESSION_KEY_PREFIX + userId;
        String value = redisTemplate.opsForValue().get(key);
        if (value == null) {
            return Optional.empty();
        }
        // Split on the first ':' only, in case the session id itself contains ':'
        String[] parts = value.split(":", 2);
        if (parts.length < 2) {
            return Optional.empty();
        }
        // SessionInfo is an application-defined holder for (instanceId, sessionId)
        return Optional.of(new SessionInfo(parts[0], parts[1]));
    }

    public void removeSession(String userId) {
        redisTemplate.delete(SESSION_KEY_PREFIX + userId);
    }
}

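For message dispatch we can use Redis publish-subscribe. When instance 1 needs to send a message to user B, it publishes the message to a channel dedicated to user B. The instance that currently holds user B's connection (possibly instance 2) is subscribed to that channel, receives the message, and delivers it through its local session.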
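
The receiving side of this scheme can be sketched without Redis. In the hypothetical `LocalDispatcher` below, each instance keeps a map of the users connected to it, and `onChannelMessage` stands in for the callback that a Redis subscriber would invoke. The channel naming (`ws:user:` plus the user ID) is an assumption made for this example, and the `Consumer<String>` represents whatever writes to the user's `WebSocketSession`.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/** Delivers messages published on per-user channels to sessions held by this instance. */
final class LocalDispatcher {

    private static final String CHANNEL_PREFIX = "ws:user:";

    // userId -> function that writes a payload to that user's local connection
    private final Map<String, Consumer<String>> localSenders = new ConcurrentHashMap<>();

    static String channelFor(String userId) {
        return CHANNEL_PREFIX + userId;
    }

    void register(String userId, Consumer<String> sender) {
        localSenders.put(userId, sender);
    }

    void unregister(String userId) {
        localSenders.remove(userId);
    }

    /**
     * Invoked when a message arrives on a subscribed channel.
     *
     * @return true if the user is connected to this instance and the payload was handed over
     */
    boolean onChannelMessage(String channel, String payload) {
        if (!channel.startsWith(CHANNEL_PREFIX)) {
            return false;
        }
        String userId = channel.substring(CHANNEL_PREFIX.length());
        Consumer<String> sender = localSenders.get(userId);
        if (sender == null) {
            return false;   // user is not connected here; another instance will deliver
        }
        sender.accept(payload);
        return true;
    }
}
```

An instance would subscribe to a user's channel when that user connects and unsubscribe on disconnect, so that each published message is delivered by exactly the instance that holds the connection.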

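6. Hands-On: Streaming LLM Responses over WebSocket

6.1 Architecture

The overall architecture for streaming LLM responses over WebSocket has three layers:

**Access layer**: establishes and manages WebSocket connections and sends and receives messages. It deals with protocol-level details such as the handshake, heartbeat, and message encoding and decoding.

**Business layer**: implements the business logic, including conversation management, context maintenance, message routing, and permission checks. This is the core of the application and connects the access layer to the LLM service layer.

**LLM service layer**: wraps calls to the various LLM APIs and handles request formatting, response parsing, and streaming data. It should be pluggable so that different model providers can be supported.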


### 4.2 使用SseEmitter实现流式响应

除了直接操作WebSocketSession,Spring还提供了`SseEmitter`作为流式输出的高级抽象。虽然SseEmitter严格来说是Server-Sent Events的实现,但其原理与WebSocket流式输出非常相似,而且可以无缝转换为WebSocket。

┌────────────────────────────────────────────────────────────────┐
│                          Client layer                          │
│           (native WebSocket API / stomp.js / SockJS)           │
└────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌────────────────────────────────────────────────────────────────┐
│                          Access layer                          │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐ │
│  │ WebSocket   │  │ Heartbeat   │  │ Message codec           │ │
│  │ Handler     │  │ Manager     │  │                         │ │
│  └─────────────┘  └─────────────┘  └─────────────────────────┘ │
└────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌────────────────────────────────────────────────────────────────┐
│                         Business layer                         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐ │
│  │ Conversation│  │ Context     │  │ Message routing         │ │
│  │ management  │  │ maintenance │  │                         │ │
│  └─────────────┘  └─────────────┘  └─────────────────────────┘ │
└────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌────────────────────────────────────────────────────────────────┐
│                       LLM service layer                        │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐ │
│  │ OpenAI-     │  │ Zhipu       │  │ Self-hosted models      │ │
│  │ compatible  │  │ GLM         │  │ (vLLM/TGI)              │ │
│  └─────────────┘  └─────────────┘  └─────────────────────────┘ │
└────────────────────────────────────────────────────────────────┘



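6.2 Conversation Context Management

Context management is the core of multi-turn dialogue. The input for each turn should include the preceding conversation history so that the model can understand the full context.

A common implementation uses a fixed-size context window. The system maintains a message list with a fixed maximum length; whenever a new message is added and the list exceeds the limit, the oldest messages are removed: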


import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

public class ConversationContext {

    private final int maxMessages;
    private final List<ChatMessage> messages = new LinkedList<>();

    public ConversationContext(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    // Append a message, then drop the oldest ones until the window fits again.
    public void addMessage(String role, String content) {
        messages.add(new ChatMessage(role, content));
        while (messages.size() > maxMessages) {
            messages.remove(0);
        }
    }

    // Return a copy so callers cannot modify the internal list.
    public List<ChatMessage> getMessages() {
        return new ArrayList<>(messages);
    }

    // Flatten the window into a "role: content" transcript, one message per line.
    public String buildPrompt() {
        StringBuilder sb = new StringBuilder();
        for (ChatMessage msg : messages) {
            sb.append(msg.getRole()).append(": ").append(msg.getContent()).append("\n");
        }
        return sb.toString();
    }

    public void clear() {
        messages.clear();
    }
}



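A more refined implementation can use a sliding window that keeps the most recent N rounds of dialogue rather than the most recent N messages, which preserves conversational coherence better.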
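A round-based window might look like the sketch below. It assumes that a round starts with a `user` message and includes every reply that follows until the next `user` message; the class name `RoundWindowContext` and its nested `Message` type are illustrative and not part of the code above.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Keeps the most recent N rounds; a round is one user message plus the replies that follow it. */
final class RoundWindowContext {

    /** Minimal message type for this sketch. */
    static final class Message {
        final String role;
        final String content;

        Message(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    private final int maxRounds;
    private final Deque<List<Message>> rounds = new ArrayDeque<>();

    RoundWindowContext(int maxRounds) {
        if (maxRounds < 1) {
            throw new IllegalArgumentException("maxRounds must be >= 1");
        }
        this.maxRounds = maxRounds;
    }

    void addMessage(String role, String content) {
        // A user message opens a new round; anything else joins the current round.
        if ("user".equals(role) || rounds.isEmpty()) {
            rounds.addLast(new ArrayList<>());
            // Evict whole rounds so a question is never separated from its answer.
            while (rounds.size() > maxRounds) {
                rounds.removeFirst();
            }
        }
        rounds.peekLast().add(new Message(role, content));
    }

    /** Returns the retained messages in chronological order. */
    List<Message> getMessages() {
        List<Message> all = new ArrayList<>();
        rounds.forEach(all::addAll);
        return all;
    }
}
```

Because eviction happens per round, the window never begins with an orphaned assistant reply whose question has already been dropped, which can happen with a purely message-count-based window.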

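6.3 Implementing Streaming Responses

Taking an OpenAI-compatible interface as the example, a streaming response requires parsing data in Server-Sent Events (SSE) format. OpenAI's streaming API returns data in the following format: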



data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"你"},"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"好"},"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}

data: [DONE]



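import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Map;

public class OpenAIStreamHandler {

    private static final String DONE = "[DONE]";

    private final WebClient webClient;
    private final ObjectMapper objectMapper = new ObjectMapper();

    public OpenAIStreamHandler(WebClient webClient) {
        this.webClient = webClient;
    }

    // ChatOptions and buildRequestBody(...) are omitted here; the request body must set "stream": true.
    public Flux<String> streamChat(String prompt, ChatOptions options) {
        Map<String, Object> requestBody = buildRequestBody(prompt, options);
        return webClient.post()
            .uri(options.getApiUrl())
            .header("Authorization", "Bearer " + options.getApiKey())
            .contentType(MediaType.APPLICATION_JSON)
            .accept(MediaType.TEXT_EVENT_STREAM)
            .bodyValue(requestBody)
            .retrieve()
            .bodyToFlux(String.class)
            // Flux.map() must not return null, so unparseable events are dropped via an empty Mono.
            .concatMap(data -> Mono.justOrEmpty(parseSSEData(data)))
            // Stop at the [DONE] sentinel without forwarding it as a token.
            .takeWhile(token -> !DONE.equals(token));
    }

    private String parseSSEData(String sseData) {
        if (sseData == null) {
            return null;
        }
        // Depending on how the response is decoded, the "data:" prefix may already be stripped.
        String json = sseData.startsWith("data:") ? sseData.substring(5).trim() : sseData.trim();
        if (json.isEmpty()) {
            return null;
        }
        if (DONE.equals(json)) {
            return DONE;
        }
        try {
            JsonNode node = objectMapper.readTree(json);
            // path(0) yields a "missing" node instead of null when choices is empty.
            JsonNode delta = node.path("choices").path(0).path("delta");
            return delta.path("content").asText(null);
        } catch (Exception e) {
            return null;
        }
    }
}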



6.4 A Complete WebSocket Handler

Combining the parts above, we can build a complete WebSocket handler for the large model:


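import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;

@Component
public class LLMWebSocketHandler extends TextWebSocketHandler {

    private static final Logger logger = LoggerFactory.getLogger(LLMWebSocketHandler.class);

    private final ObjectMapper objectMapper = new ObjectMapper();
    private final SessionRegistry sessionRegistry;
    private final LLMService llmService;
    private final ConversationManager conversationManager;

    public LLMWebSocketHandler(SessionRegistry sessionRegistry, LLMService llmService,
                               ConversationManager conversationManager) {
        this.sessionRegistry = sessionRegistry;
        this.llmService = llmService;
        this.conversationManager = conversationManager;
    }

    // The helpers handleClear(...), sendError(...) and escapeJson(...) are omitted here.

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        try {
            JsonNode json = objectMapper.readTree(message.getPayload());
            // path() returns a "missing" node rather than null, so an absent type falls through to default.
            String type = json.path("type").asText();
            switch (type) {
                case "chat":
                    handleChat(session, json);
                    break;
                case "ping":
                    session.sendMessage(new TextMessage("{\"type\":\"pong\"}"));
                    break;
                case "clear":
                    handleClear(session);
                    break;
                default:
                    sendError(session, "Unknown message type");
            }
        } catch (Exception e) {
            sendError(session, "Invalid message format");
        }
    }

    private void handleChat(WebSocketSession session, JsonNode json) {
        String prompt = json.path("prompt").asText();
        String conversationId = getOrCreateConversationId(session);
        conversationManager.addMessage(conversationId, "user", prompt);

        // WebSocketSession.sendMessage is not safe for concurrent calls. If several streams can write
        // to one session at the same time, wrap the session in ConcurrentWebSocketSessionDecorator.
        llmService.streamChat(conversationId, prompt)
            .subscribe(
                token -> {
                    try {
                        String data = "{\"type\":\"chunk\",\"data\":\"" +
                            escapeJson(token) + "\"}";
                        session.sendMessage(new TextMessage(data));
                    } catch (IOException e) {
                        logger.error("Failed to send chunk", e);
                    }
                },
                error -> sendError(session, error.getMessage()),
                () -> {
                    try {
                        session.sendMessage(new TextMessage("{\"type\":\"done\"}"));
                    } catch (IOException e) {
                        logger.error("Failed to send done", e);
                    }
                }
            );
    }

    private String getOrCreateConversationId(WebSocketSession session) {
        String userId = (String) session.getAttributes().get("userId");
        return conversationManager.getOrCreate(userId);
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        sessionRegistry.register(session);
        logger.info("WebSocket connection established: {}", session.getId());
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        sessionRegistry.unregister(session);
        logger.info("WebSocket connection closed: {}", session.getId());
    }
}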
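The handler builds its chunk frames by string concatenation and relies on an `escapeJson` helper that is not shown. A minimal sketch of what such a helper could look like follows; the class name `JsonEscaper` is hypothetical, and in practice serializing the frame with a JSON library (for example Jackson's `ObjectMapper.writeValueAsString`) is usually the safer choice.

```java
/** Hypothetical stand-in for the escapeJson helper referenced by the handler. */
final class JsonEscaper {
    private JsonEscaper() {
    }

    /** Escapes a string so it can be placed between double quotes in a JSON document. */
    static String escapeJson(String text) {
        StringBuilder sb = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                case '\b':
                    sb.append("\\b");
                    break;
                case '\f':
                    sb.append("\\f");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters need the four-hex-digit unicode escape form.
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }
}
```

Tokens from a model regularly contain quotes, backslashes, and newlines (code snippets in particular), so an unescaped token would produce a frame that the client's `JSON.parse` rejects.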



Figure: WebSocket long-lived connection architecture (websocket_architecture.png)

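7. Front-End WebSocket Client Implementation

7.1 The Native WebSocket API

Modern browsers ship with a built-in WebSocket API that can be used directly, without any third-party library. The native API is deliberately small: its core is the `WebSocket` class plus a handful of event callbacks.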

class LLMWebSocketClient {
    constructor(url) {
        this.url = url;
        this.ws = null;
        this.reconnectAttempts = 0;
        this.maxReconnectAttempts = 5;
        this.reconnectDelay = 1000;
        this.messageHandlers = [];
    }

    connect() {
        this.ws = new WebSocket(this.url);

        this.ws.onopen = (event) => {
            console.log('WebSocket connected');
            this.reconnectAttempts = 0;
            this.onConnected && this.onConnected();
        };

        this.ws.onmessage = (event) => {
            const data = JSON.parse(event.data);
            this.handleMessage(data);
        };

        this.ws.onerror = (error) => {
            console.error('WebSocket error:', error);
            this.onError && this.onError(error);
        };

        this.ws.onclose = (event) => {
            console.log('WebSocket closed:', event.code, event.reason);
            this.attemptReconnect();
        };
    }

    handleMessage(data) {
        switch (data.type) {
            case 'chunk':
                this.onToken && this.onToken(data.data);
                break;
            case 'done':
                this.onComplete && this.onComplete();
                break;
            case 'error':
                this.onError && this.onError(new Error(data.message));
                break;
            case 'pong':
                // Heartbeat response
                break;
        }

        this.messageHandlers.forEach(handler => handler(data));
    }

    send(data) {
        if (this.ws && this.ws.readyState === WebSocket.OPEN) {
            this.ws.send(JSON.stringify(data));
        }
    }

    sendChat(prompt) {
        this.send({
            type: 'chat',
            prompt: prompt
        });
    }

    ping() {
        this.send({ type: 'ping' });
    }

    attemptReconnect() {
        // Do not reconnect after an explicit disconnect().
        if (this.closedByUser) {
            return;
        }

        if (this.reconnectAttempts >= this.maxReconnectAttempts) {
            console.log('Max reconnection attempts reached');
            this.onMaxReconnectAttemptsReached &&
                this.onMaxReconnectAttemptsReached();
            return;
        }

        // Exponential backoff: 1s, 2s, 4s, ...
        this.reconnectAttempts++;
        const delay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts - 1);

        console.log(`Attempting reconnection in ${delay}ms...`);
        setTimeout(() => this.connect(), delay);
    }

    disconnect() {
        this.closedByUser = true;  // Prevent automatic reconnection
        this.ws && this.ws.close();
    }

    on(event, handler) {
        if (event === 'token') this.onToken = handler;
        else if (event === 'complete') this.onComplete = handler;
        else if (event === 'error') this.onError = handler;
        else if (event === 'connected') this.onConnected = handler;
        else if (event === 'maxReconnectAttemptsReached') {
            this.onMaxReconnectAttemptsReached = handler;
        }
    }
}



Example usage of this client:


// responseContainer, inputField and sendButton are DOM elements looked up elsewhere.
const client = new LLMWebSocketClient('wss://api.example.com/ws/llm');

// Register callbacks
client.on('connected', () => {
    console.log('Connected to LLM server');
});

client.on('token', (token) => {
    // Append each token as plain text; writing to innerHTML would let model output inject markup
    responseContainer.append(token);
});

client.on('complete', () => {
    console.log('Response complete');
});

client.on('error', (error) => {
    console.error('Error:', error.message);
    alert('Error: ' + error.message);
});

// Open the connection
client.connect();

// Send a message
sendButton.addEventListener('click', () => {
    const prompt = inputField.value;
    client.sendChat(prompt);
});



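7.2 SockJS and STOMP.js

In real projects, especially those that must support older browsers or need richer messaging features, SockJS and STOMP.js are common choices. SockJS provides fallbacks for WebSocket: when WebSocket is unavailable, it automatically downgrades to other techniques such as long polling. STOMP.js provides a publish-subscribe messaging API.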


// Use STOMP over WebSocket (conversationId, appendToken and markComplete are defined elsewhere)
const socket = new SockJS('/ws');
const stompClient = Stomp.over(socket);

stompClient.connect(
    {},
    (frame) => {
        console.log('Connected:', frame);

        // Subscribe to the topic for this conversation
        stompClient.subscribe('/topic/response/' + conversationId,
            (message) => {
                const data = JSON.parse(message.body);
                if (data.type === 'chunk') {
                    appendToken(data.content);
                } else if (data.type === 'done') {
                    markComplete();
                }
            }
        );
    },
    (error) => {
        console.error('Connection error:', error);
    }
);

// Send a message
function sendChat(prompt) {
    stompClient.send(
        '/app/chat',
        {},
        JSON.stringify({
            conversationId: conversationId,
            prompt: prompt
        })
    );
}



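7.3 Message Display and UI Interaction

The design of the user interface directly shapes the user experience. A large-model chat interface should take the following points into account:

**Immediate feedback**: user input should get instant visual feedback so that the user knows it has been received. After being clicked, the send button should switch to a loading state until the completion signal arrives.

**Progressive rendering**: received tokens should be displayed immediately, without waiting for a particular delimiter or for the completion signal. `requestAnimationFrame` can be used to optimize rendering performance: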


let pendingText = '';
let rafId = null;

// Batch tokens and write them to the DOM at most once per animation frame.
function scheduleRender() {
    if (rafId === null) {
        rafId = requestAnimationFrame(() => {
            // Append as a text node so model output is never interpreted as HTML
            responseContainer.append(pendingText);
            pendingText = '';
            rafId = null;

            // Auto-scroll to the bottom
            responseContainer.scrollTop = responseContainer.scrollHeight;
        });
    }
}

client.on('token', (token) => {
    pendingText += token;
    scheduleRender();
});

client.on('complete', () => {
    // Cancel any pending frame and flush the remaining text immediately
    if (rafId !== null) {
        cancelAnimationFrame(rafId);
        rafId = null;
    }
    if (pendingText) {
        responseContainer.append(pendingText);
        pendingText = '';
    }
});



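**Markdown rendering**: large-model output is usually in Markdown and needs to be rendered before display. The `marked.js` library can parse the Markdown, and `highlight.js` can highlight code blocks: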


client.on('complete', () => {
    const html = marked.parse(responseContainer.innerHTML);
    responseContainer.innerHTML = DOMPurify.sanitize(html);

    // Highlight code blocks
    responseContainer.querySelectorAll('pre code').forEach((block) => {
        hljs.highlightElement(block);
    });
});


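**Error recovery**: When an error occurs, show a clear error message and offer a retry option. The error message should be appended to the partial content already displayed rather than replacing it: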

client.on('error', (error) => {
    const errorDiv = document.createElement('div');
    errorDiv.className = 'error-message';
    errorDiv.textContent = 'Error: ' + error.message;
    responseContainer.appendChild(errorDiv);

    // Show a retry button
    const retryBtn = document.createElement('button');
    retryBtn.textContent = 'Retry';
    retryBtn.onclick = () => {
        errorDiv.remove();
        retryBtn.remove();
        sendChat(lastPrompt);
    };
    responseContainer.appendChild(retryBtn);
});


Figure: WebSocket protocol handshake and communication flow (websocket_protocol.png)

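8. Concurrent Connection Management and Heartbeat Detection

8.1 Connection Pooling and Resource Management

Under high concurrency, managing the resources behind WebSocket connections becomes especially important. Each connection consumes a file descriptor, memory, and some CPU time. Without a limit on the number of connections, the server can run out of resources.

Spring lets you configure container-level WebSocket settings, such as the message buffer sizes and the session idle timeout, through `ServletServerContainerFactoryBean`: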

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;

@Configuration
public class WebSocketConfig {

    @Bean
    public ServletServerContainerFactoryBean createServletServerContainerFactoryBean() {
        ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean();
        container.setMaxTextMessageBufferSize(8192);    // Text message buffer size
        container.setMaxBinaryMessageBufferSize(8192);  // Binary message buffer size
        container.setMaxSessionIdleTimeout(3600000L);   // Maximum session idle time (1 hour, in milliseconds)
        return container;
    }
}


Beyond these configuration parameters, the number of connections also has to be limited in code:

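import java.util.concurrent.atomic.AtomicInteger;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

@Service
public class ConnectionLimiter {

    private final AtomicInteger connectionCount = new AtomicInteger(0);
    private final int maxConnections;

    // The property name and default value are illustrative; adapt them to your configuration
    public ConnectionLimiter(@Value("${websocket.max-connections:1000}") int maxConnections) {
        this.maxConnections = maxConnections;
    }

    // Reserves a slot. A rejected attempt rolls the counter back so it does not leak.
    public boolean tryAcquire() {
        if (connectionCount.incrementAndGet() > maxConnections) {
            connectionCount.decrementAndGet();
            return false;
        }
        return true;
    }

    // Call once for every successful tryAcquire(), for example when the connection closes
    public void release() {
        connectionCount.decrementAndGet();
    }

    public int getCurrentConnections() {
        return connectionCount.get();
    }
}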


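8.2 Implementing the Heartbeat Mechanism

Heartbeats are an important way to keep a connection alive. A sensible heartbeat strategy detects dropped connections promptly so that their resources can be released.

**Server-side heartbeat**: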

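import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.WebSocketSession;

@Component
public class HeartbeatManager {

    private static final Logger logger = LoggerFactory.getLogger(HeartbeatManager.class);

    private static final long HEARTBEAT_INTERVAL = 30;  // seconds
    private static final long HEARTBEAT_TIMEOUT = 10;   // seconds; declared but not enforced in this example

    private final Map<String, ScheduledFuture<?>> heartbeatTasks = new ConcurrentHashMap<>();
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

    public void startHeartbeat(WebSocketSession session) {
        ScheduledFuture<?> task = scheduler.scheduleAtFixedRate(() -> {
            if (session.isOpen()) {
                try {
                    // Send a Ping frame
                    session.sendMessage(new PingMessage());
                } catch (IOException e) {
                    logger.warn("Failed to send ping to session {}", session.getId());
                    closeSession(session);
                }
            } else {
                // The session is already closed, so stop pinging it
                stopHeartbeat(session.getId());
            }
        }, HEARTBEAT_INTERVAL, HEARTBEAT_INTERVAL, TimeUnit.SECONDS);

        heartbeatTasks.put(session.getId(), task);
    }

    public void stopHeartbeat(String sessionId) {
        ScheduledFuture<?> task = heartbeatTasks.remove(sessionId);
        if (task != null) {
            // No interrupt needed; this is also safe when called from inside the task itself
            task.cancel(false);
        }
    }

    public void closeSession(WebSocketSession session) {
        stopHeartbeat(session.getId());
        try {
            session.close(CloseStatus.GOING_AWAY);
        } catch (IOException e) {
            logger.error("Error closing session", e);
        }
    }
}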


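**Client-side heartbeat**:

Browsers answer the server's Ping frames with Pong frames automatically. Because the browser WebSocket API cannot send protocol-level Ping frames itself, the client can additionally send application-level heartbeat messages to the server: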

class HeartbeatClient {
    constructor(wsClient) {
        this.wsClient = wsClient;
        this.pingInterval = null;
        this.pongTimeout = null;
        this.heartbeatInterval = 30000;  // 30 seconds
        this.pongTimeoutMs = 5000;       // Treat the connection as dead after 5 seconds without a pong
    }

    start() {
        this.pingInterval = setInterval(() => {
            this.sendPing();
        }, this.heartbeatInterval);
    }

    sendPing() {
        this.wsClient.ping();

        // Start the pong timeout timer
        this.pongTimeout = setTimeout(() => {
            console.warn('Pong timeout, reconnecting...');
            this.wsClient.disconnect();
            this.wsClient.connect();
        }, this.pongTimeoutMs);
    }

    // Call this when a 'pong' message arrives from the server
    onPong() {
        // Pong received: clear the timeout timer
        if (this.pongTimeout) {
            clearTimeout(this.pongTimeout);
            this.pongTimeout = null;
        }
    }

    stop() {
        if (this.pingInterval) {
            clearInterval(this.pingInterval);
            this.pingInterval = null;
        }
        if (this.pongTimeout) {
            clearTimeout(this.pongTimeout);
            this.pongTimeout = null;
        }
    }
}
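
The timeout logic in `HeartbeatClient` only works if `onPong()` is actually invoked when a `pong` message arrives. As a minimal sketch of the same idea without real timers, the hypothetical `PongTracker` below (not part of the original client) records when a ping was sent and reports whether it has gone unanswered for too long. Timestamps are passed in explicitly, which keeps the logic easy to test:

```javascript
// Hypothetical helper (not from the article): tracks an unanswered ping
// using caller-supplied timestamps instead of real timers.
class PongTracker {
    constructor(timeoutMs) {
        this.timeoutMs = timeoutMs;
        this.pendingSince = null;  // time the oldest unanswered ping was sent
    }

    // Record that a ping was sent at `now` (milliseconds).
    pingSent(now) {
        if (this.pendingSince === null) {
            this.pendingSince = now;
        }
    }

    // Record that a pong arrived; no ping is outstanding any more.
    pongReceived() {
        this.pendingSince = null;
    }

    // True when a ping has been waiting longer than timeoutMs.
    isStale(now) {
        return this.pendingSince !== null &&
            now - this.pendingSince > this.timeoutMs;
    }
}
```

In a real client, `pongReceived()` (or `HeartbeatClient.onPong()`) would be called from the `'pong'` branch of the message handler.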


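8.3 Connection Health Checks

For operations and monitoring, it is important to know the health of the current connections. We can expose a health-check endpoint that returns metrics about them: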

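import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

// ConnectionRegistry and SessionInfo are application classes that are not shown here
@RestController
@RequestMapping("/api/ws")
public class WebSocketHealthController {

    private final ConnectionRegistry registry;
    private final HeartbeatManager heartbeatManager;

    public WebSocketHealthController(ConnectionRegistry registry, HeartbeatManager heartbeatManager) {
        this.registry = registry;
        this.heartbeatManager = heartbeatManager;
    }

    @GetMapping("/health")
    public Map<String, Object> getHealth() {
        Map<String, Object> health = new HashMap<>();
        health.put("totalConnections", registry.getConnectionCount());
        health.put("activeConnections", registry.getActiveConnectionCount());
        health.put("connectionsByUser", registry.getConnectionsByUser());
        health.put("averageSessionDuration", registry.getAverageSessionDuration());
        health.put("timestamp", System.currentTimeMillis());
        return health;
    }

    @GetMapping("/sessions")
    public List<SessionInfo> getSessions() {
        // Assumes the registry returns session objects that expose getOpenTime();
        // Spring's WebSocketSession itself has no such method
        return registry.getAllSessions().stream()
            .map(session -> new SessionInfo(
                session.getId(),
                session.getAttributes().get("userId"),
                session.getOpenTime(),
                session.isOpen()
            ))
            .collect(Collectors.toList());
    }
}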


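9. Security, Authentication, and Token Refresh

9.1 Designing the Authentication Mechanism

Authenticating a WebSocket connection is more involved than authenticating an ordinary HTTP request. However, the WebSocket handshake is itself an HTTP request, so we can authenticate during the handshake.

**Option 1: URL parameter authentication**

Pass the authentication token as a URL parameter when opening the WebSocket connection: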

const ws = new WebSocket('wss://api.example.com/ws?token=' + encodeURIComponent(authToken));


On the server, a handshake interceptor validates the token:

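import java.util.Map;

import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

// TokenService and User are application classes
public class TokenAuthHandshakeInterceptor implements HandshakeInterceptor {

    private final TokenService tokenService;

    public TokenAuthHandshakeInterceptor(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    @Override
    public boolean beforeHandshake(ServerHttpRequest request,
                                   ServerHttpResponse response,
                                   WebSocketHandler handler,
                                   Map<String, Object> attributes) {
        String token = extractToken(request);
        if (token == null || !tokenService.validateToken(token)) {
            return false;  // Reject the handshake
        }

        User user = tokenService.getUserFromToken(token);
        attributes.put("userId", user.getId());
        attributes.put("userRole", user.getRole());
        return true;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler handler, Exception exception) {
        // Nothing to do after the handshake
    }

    // Reads the "token" query parameter. Unlike a plain substring search,
    // this does not match parameters such as "mytoken=".
    private String extractToken(ServerHttpRequest request) {
        return UriComponentsBuilder.fromUri(request.getURI())
            .build()
            .getQueryParams()
            .getFirst("token");
    }
}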


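This approach is simple and direct. Its drawback is that the token can end up in browser history, server logs, and similar places, which weakens security.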
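
If the token does travel in the URL, it is safer to encode and parse it as a real query parameter than to rely on string concatenation or substring search. A minimal sketch using the standard `URL` API follows; the function names `buildWsUrl` and `extractToken` are illustrative and not part of the original code:

```javascript
// Illustrative helpers (names are assumptions, not from the article).

// Append the token as a properly encoded "token" query parameter.
function buildWsUrl(baseUrl, token) {
    const url = new URL(baseUrl);
    url.searchParams.set('token', token);
    return url.toString();
}

// Read the "token" query parameter, or null if it is absent.
function extractToken(rawUrl) {
    return new URL(rawUrl).searchParams.get('token');
}
```

Because the lookup is by parameter name, a URL that only contains `mytoken=...` yields `null` instead of a false match.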

**Option 2: Cookie authentication**

If the user has already logged in over HTTP, the credentials may be stored in a cookie. The handshake can read them from that cookie:

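import java.util.Arrays;
import java.util.Map;

import jakarta.servlet.http.Cookie;              // Use javax.servlet.http on Spring Boot 2.x
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

// TokenService and User are application classes
public class CookieAuthHandshakeInterceptor implements HandshakeInterceptor {

    private final TokenService tokenService;

    public CookieAuthHandshakeInterceptor(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    @Override
    public boolean beforeHandshake(ServerHttpRequest request,
                                   ServerHttpResponse response,
                                   WebSocketHandler handler,
                                   Map<String, Object> attributes) {
        // ServerHttpRequest is not an HttpServletRequest; unwrap the servlet request first
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();

        Cookie[] cookies = httpRequest.getCookies();
        if (cookies == null) return false;

        String token = Arrays.stream(cookies)
            .filter(c -> "auth_token".equals(c.getName()))
            .findFirst()
            .map(Cookie::getValue)
            .orElse(null);

        if (token == null || !tokenService.validateToken(token)) {
            return false;
        }

        User user = tokenService.getUserFromToken(token);
        attributes.put("userId", user.getId());
        return true;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler handler, Exception exception) {
        // Nothing to do after the handshake
    }
}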


**Option 3: STOMP header authentication**

When using the STOMP protocol, the credentials can be carried in the CONNECT frame:

stompClient.connect(
    {
        'Authorization': 'Bearer ' + token
    },
    (frame) => {
        console.log('Connected');
    }
);


On the server, a `ChannelInterceptor` intercepts the CONNECT frame and validates it:

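import java.util.Map;

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;

// TokenService and User are application classes
public class StompAuthInterceptor implements ChannelInterceptor {

    private final TokenService tokenService;

    public StompAuthInterceptor(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(
            message, StompHeaderAccessor.class);

        // The accessor is null for messages that carry no STOMP headers
        if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
            String authHeader = accessor.getFirstNativeHeader("Authorization");
            if (authHeader == null || !authHeader.startsWith("Bearer ")) {
                throw new IllegalArgumentException("Missing or invalid Authorization header");
            }

            String token = authHeader.substring(7);  // Strip the "Bearer " prefix
            if (!tokenService.validateToken(token)) {
                throw new IllegalArgumentException("Invalid token");
            }

            User user = tokenService.getUserFromToken(token);
            Map<String, Object> sessionAttributes = accessor.getSessionAttributes();
            if (sessionAttributes != null) {
                sessionAttributes.put("userId", user.getId());
            }
        }

        return message;
    }
}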


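9.2 Token Refresh

Short-lived access tokens (such as JWTs) are typically valid for only tens of minutes. On a long-running WebSocket connection, the token has to be refreshed before it expires.

One approach: when the client is warned that the token is about to expire, it first requests a refreshed token and then re-establishes the connection with the new token: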

class TokenRefreshClient {
    constructor(wsClient, tokenService) {
        this.wsClient = wsClient;
        this.tokenService = tokenService;
        this.refreshThreshold = 5 * 60 * 1000;  // Refresh 5 minutes before expiry
        this.refreshTimer = null;
    }

    start() {
        this.scheduleRefresh();
    }

    scheduleRefresh() {
        const expiresAt = this.tokenService.getExpiresAt();
        const now = Date.now();
        const delay = expiresAt - now - this.refreshThreshold;

        if (delay <= 0) {
            this.refresh();
        } else {
            this.refreshTimer = setTimeout(() => this.refresh(), delay);
        }
    }

    async refresh() {
        try {
            const newToken = await this.tokenService.refresh();
            console.log('Token refreshed successfully');

            // Hand the new token to the application so that the reconnect below can use it
            // (onTokenRefreshed is an optional hook added here for illustration)
            this.onTokenRefreshed && this.onTokenRefreshed(newToken);

            // Reconnect
            this.wsClient.disconnect();
            this.wsClient.connect();

            // Schedule the next refresh
            this.scheduleRefresh();
        } catch (error) {
            console.error('Token refresh failed:', error);
            // The user may need to log in again
            this.onRefreshFailed && this.onRefreshFailed(error);
        }
    }

    stop() {
        if (this.refreshTimer) {
            clearTimeout(this.refreshTimer);
            this.refreshTimer = null;
        }
    }
}


The server must include the token's expiration time in its response so that the client can work out when to refresh:

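import java.security.Key;
import java.util.Date;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;

// Uses the JJWT 0.11.x API. validateToken and getUserFromToken are not shown here.
public class TokenService {

    // Illustrative values: a real service would load the signing key
    // and the validity period from configuration
    private final Key key = Keys.secretKeyFor(SignatureAlgorithm.HS256);
    private final long tokenValidityMs = 30 * 60 * 1000L;

    public String generateToken(User user) {
        Date now = new Date();
        Date expiry = new Date(now.getTime() + tokenValidityMs);

        return Jwts.builder()
            .setSubject(user.getId())
            .claim("role", user.getRole())
            .setIssuedAt(now)
            .setExpiration(expiry)
            .signWith(key)
            .compact();
    }

    // Returns the expiration time as epoch milliseconds
    public long getExpiresAt(String token) {
        Claims claims = Jwts.parserBuilder()
            .setSigningKey(key)
            .build()
            .parseClaimsJws(token)
            .getBody();
        return claims.getExpiration().getTime();
    }
}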
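
On the client, the expiration time can also be read from the JWT itself, because the `exp` claim holds the expiry in seconds since the epoch. The sketch below decodes the payload without verifying the signature, which is acceptable only for scheduling a refresh (the server must still validate the token). `decodeJwtPayload` and `refreshDelayMs` are hypothetical helper names:

```javascript
// Hypothetical helpers for scheduling a refresh; they do NOT verify the token.

// Decode the base64url-encoded payload segment of a JWT into an object.
// atob yields a binary string, so this assumes an ASCII-only payload.
function decodeJwtPayload(token) {
    const parts = token.split('.');
    if (parts.length !== 3) {
        throw new Error('Malformed JWT');
    }
    const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/');
    const padded = base64 + '='.repeat((4 - (base64.length % 4)) % 4);
    return JSON.parse(atob(padded));
}

// Milliseconds to wait before refreshing, never negative.
// `exp` is in seconds; `now` and `thresholdMs` are in milliseconds.
function refreshDelayMs(token, now, thresholdMs) {
    const { exp } = decodeJwtPayload(token);
    return Math.max(0, exp * 1000 - now - thresholdMs);
}
```

A caller would typically pass `Date.now()` as `now` and use the result as the `setTimeout` delay.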


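9.3 Access Control

Access control for WebSocket messages is similar to that for HTTP requests, with a few points that deserve particular attention:

1. **Permission checks at handshake**: verify the user's permissions when the connection is established rather than waiting for the first message

2. **Message-level permission checks**: for sensitive operations, verify permissions on every message

3. **Topic subscription control**: with STOMP, verify that the user is allowed to subscribe to a given topic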

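import java.util.Map;

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;

// PermissionService is an application class that is not shown here
public class PermissionChannelInterceptor implements ChannelInterceptor {

    private final PermissionService permissionService;

    public PermissionChannelInterceptor(PermissionService permissionService) {
        this.permissionService = permissionService;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(
            message, StompHeaderAccessor.class);
        if (accessor == null) {
            return message;  // Not a STOMP message
        }

        String destination = accessor.getDestination();
        Map<String, Object> attributes = accessor.getSessionAttributes();
        String userId = attributes == null ? null : (String) attributes.get("userId");

        // MessagingException is unchecked; the checked IllegalAccessException
        // cannot be thrown from preSend
        if (StompCommand.SUBSCRIBE.equals(accessor.getCommand())
                && !permissionService.canSubscribe(userId, destination)) {
            throw new MessagingException(message,
                "User " + userId + " cannot subscribe to " + destination);
        }

        if (StompCommand.SEND.equals(accessor.getCommand())
                && !permissionService.canSend(userId, destination)) {
            throw new MessagingException(message,
                "User " + userId + " cannot send to " + destination);
        }

        return message;
    }
}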
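
What `permissionService.canSubscribe` actually checks depends on the application. As one possible rule set, assumed purely for illustration, the sketch below lets anyone subscribe to destinations under `/topic/public/` and restricts destinations under `/user/{id}/` to the user with that id. Everything else is denied by default:

```javascript
// Illustrative rule set (an assumption, not the article's PermissionService):
//  - /topic/public/...  may be subscribed to by anyone
//  - /user/{id}/...     may be subscribed to only by user {id}
//  - anything else is denied
function canSubscribe(userId, destination) {
    if (typeof destination !== 'string') {
        return false;
    }
    if (destination.startsWith('/topic/public/')) {
        return true;
    }
    const match = /^\/user\/([^/]+)\//.exec(destination);
    return match !== null && match[1] === userId;
}
```

Denying by default means a newly added destination stays inaccessible until a rule explicitly allows it.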


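10. Complete Example: Real-Time Chat with WebSocket and an LLM

10.1 Project Structure

Let's build a complete real-time LLM chat system. The project uses a Spring Boot + WebSocket + Vue.js architecture: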


llm-websocket-chat/
├── backend/
│   ├── src/main/java/com/example/llmchat/
│   │   ├── LlmChatApplication.java
│   │   ├── config/
│   │   │   ├── WebSocketConfig.java
│   │   │   ├── CorsConfig.java
│   │   │   └── SecurityConfig.java
│   │   ├── websocket/
│   │   │   ├── LLMWebSocketHandler.java
│   │   │   ├── AuthHandshakeInterceptor.java
│   │   │   └── HeartbeatInterceptor.java
│   │   ├── service/
│   │   │   ├── LLMService.java
│   │   │   ├── OpenAIClient.java
│   │   │   ├── ConversationService.java
│   │   │   └── TokenService.java
│   │   ├── model/
│   │   │   ├── ChatMessage.java
│   │   │   ├── Conversation.java
│   │   │   └── User.java
│   │   └── controller/
│   │       └── ChatController.java
│   └── src/main/resources/
│       └── application.yml
└── frontend/
    ├── index.html
    ├── src/
    │   ├── main.js
    │   ├── App.vue
    │   ├── components/
    │   │   ├── ChatWindow.vue
    │   │   ├── MessageBubble.vue
    │   │   └── InputArea.vue
    │   └── services/
    │       └── websocket.js
    └── package.json


10.2 Backend Implementation

**WebSocket configuration**:

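import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    private final LLMWebSocketHandler llmWebSocketHandler;
    private final AuthHandshakeInterceptor authInterceptor;
    private final HeartbeatInterceptor heartbeatInterceptor;

    public WebSocketConfig(LLMWebSocketHandler llmWebSocketHandler,
                           AuthHandshakeInterceptor authInterceptor,
                           HeartbeatInterceptor heartbeatInterceptor) {
        this.llmWebSocketHandler = llmWebSocketHandler;
        this.authInterceptor = authInterceptor;
        this.heartbeatInterceptor = heartbeatInterceptor;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(llmWebSocketHandler, "/ws/llm")
            .addInterceptors(authInterceptor, heartbeatInterceptor)
            .setAllowedOrigins("*");  // Allows every origin; restrict to trusted origins in production
    }
}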


**Core WebSocket handler**:

import java.io.IOException;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

// ChatRequest, ChatResponse, StreamObserver, and SessionRegistry are application classes
@Component
@Slf4j
@RequiredArgsConstructor
public class LLMWebSocketHandler extends TextWebSocketHandler {

    private final SessionRegistry sessionRegistry;
    private final LLMService llmService;
    private final ConversationService conversationService;
    private final ObjectMapper objectMapper;

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        try {
            ChatRequest request = objectMapper.readValue(
                message.getPayload(), ChatRequest.class);

            switch (request.getType()) {
                case "chat":
                    handleChat(session, request);
                    break;
                case "ping":
                    sendPong(session);
                    break;
                case "clear":
                    handleClear(session, request);  // handleClear is not shown in this excerpt
                    break;
                default:
                    sendError(session, "Unknown request type");
            }
        } catch (Exception e) {
            log.error("Error handling message", e);
            sendError(session, "Invalid request: " + e.getMessage());
        }
    }

    private void handleChat(WebSocketSession session, ChatRequest request) {
        String userId = getUserId(session);

        // Must be effectively final so that the anonymous observer below can capture it
        final String conversationId = request.getConversationId() != null
            ? request.getConversationId()
            : conversationService.createConversation(userId);

        // Add the user message
        conversationService.addMessage(conversationId, "user", request.getContent());

        // Build a prompt that includes the conversation context
        String prompt = conversationService.buildPrompt(conversationId);

        // Call the LLM in streaming mode
        llmService.streamGenerate(prompt, new StreamObserver() {
            private final StringBuilder response = new StringBuilder();

            @Override
            public void onToken(String token) {
                response.append(token);
                sendChunk(session, token);
            }

            @Override
            public void onComplete() {
                // Save the assistant's reply
                conversationService.addMessage(
                    conversationId, "assistant", response.toString());
                sendDone(session, conversationId);
            }

            @Override
            public void onError(Throwable error) {
                sendError(session, error.getMessage());
            }
        });
    }

    private void sendChunk(WebSocketSession session, String chunk) {
        send(session, new ChatResponse("chunk", chunk, null));
    }

    private void sendDone(WebSocketSession session, String conversationId) {
        send(session, new ChatResponse("done", null, conversationId));
    }

    private void sendError(WebSocketSession session, String error) {
        send(session, new ChatResponse("error", error, null));
    }

    private void sendPong(WebSocketSession session) {
        send(session, new ChatResponse("pong", null, null));
    }

    // Serializes the response and writes it to the session. A WebSocketSession does not
    // support concurrent sends, so writes are synchronized per session.
    private void send(WebSocketSession session, ChatResponse response) {
        try {
            String json = objectMapper.writeValueAsString(response);
            synchronized (session) {
                session.sendMessage(new TextMessage(json));
            }
        } catch (IOException e) {
            log.error("Failed to send message", e);
        }
    }

    private String getUserId(WebSocketSession session) {
        return (String) session.getAttributes().get("userId");
    }
}


**LLM service**:

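import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.List;
import java.util.Map;

@Service
@Slf4j
public class LLMService {

    private final WebClient webClient;
    private final ObjectMapper objectMapper;

    public LLMService() {
        this.webClient = WebClient.builder()
            .baseUrl("https://api.openai.com")
            .defaultHeader("Content-Type", "application/json")
            .build();
        this.objectMapper = new ObjectMapper();
    }

    // Sends the prompt with stream=true and forwards each generated token to the observer
    public void streamGenerate(String prompt, StreamObserver observer) {
        Map<String, Object> requestBody = Map.of(
            "model", "gpt-3.5-turbo",
            "messages", List.of(Map.of("role", "user", "content", prompt)),
            "stream", true
        );

        webClient.post()
            .uri("/v1/chat/completions")
            .header("Authorization", "Bearer " + getApiKey())
            .bodyValue(requestBody)
            .retrieve()
            .bodyToFlux(String.class)
            .subscribe(
                data -> {
                    String token = parseStreamToken(data);
                    // Skip events without text, such as the initial role-only delta
                    if (token != null && !token.isEmpty()) {
                        observer.onToken(token);
                    }
                },
                observer::onError,
                observer::onComplete
            );
    }

    // Extracts the text delta from one streamed event; returns null when there is none
    private String parseStreamToken(String data) {
        if (data == null || data.isBlank()) {
            return null;
        }
        // For a text/event-stream response, WebClient usually delivers only the event
        // payload, so the "data:" prefix is stripped when present instead of being required
        String json = data.startsWith("data:") ? data.substring(5).trim() : data.trim();
        if ("[DONE]".equals(json)) {
            return null;
        }
        try {
            JsonNode node = objectMapper.readTree(json);
            // path(0) yields a missing node instead of null when "choices" is empty
            return node.path("choices").path(0).path("delta").path("content").asText(null);
        } catch (Exception e) {
            log.warn("Skipping unparseable stream chunk: {}", json, e);
            return null;
        }
    }

    private String getApiKey() {
        // Read the API key from configuration or a secret store
        return System.getenv("OPENAI_API_KEY");
    }
}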
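
To make the parsing step concrete, the sketch below replays a few hand-written sample lines through the same checks as `parseStreamToken`. It is written in JavaScript only because `JSON.parse` keeps it dependency-free. The sample payloads are simplified assumptions that keep just the `choices[0].delta.content` path read by the Java code, and unlike a strict prefix check the sketch also accepts lines whose `data:` prefix has already been stripped by the HTTP client.

```javascript
// Returns the text delta carried by one streamed line, or null if there is none.
function parseStreamToken(line) {
    if (typeof line !== 'string' || line.trim() === '') return null;
    // Accept both "data: {...}" and a bare "{...}" payload
    const payload = line.startsWith('data:') ? line.slice(5).trim() : line.trim();
    if (payload === '[DONE]') return null; // end-of-stream marker
    try {
        const content = JSON.parse(payload)?.choices?.[0]?.delta?.content;
        return typeof content === 'string' ? content : null;
    } catch (e) {
        return null; // malformed chunk: skip it
    }
}

// Simplified sample lines (assumed shapes, not captured from a real response)
const sampleLines = [
    'data: {"choices":[{"delta":{"role":"assistant"}}]}',
    'data: {"choices":[{"delta":{"content":"Hel"}}]}',
    '{"choices":[{"delta":{"content":"lo"}}]}',
    'data: [DONE]'
];

const reply = sampleLines
    .map(parseStreamToken)
    .filter(token => token !== null)
    .join('');
console.log(reply);
```

Only the two lines that carry a `content` field contribute text; the role-only delta and the `[DONE]` marker are dropped.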



10.3 Frontend Implementation

**WebSocket service**:


class ChatWebSocketService {
    constructor() {
        this.socket = null;
        this.reconnectAttempts = 0;
        this.maxReconnectAttempts = 10;
        // Set by disconnect() so that a deliberate close does not trigger a reconnect
        this.manualClose = false;
        this.listeners = {
            chunk: [],
            done: [],
            error: [],
            connected: [],
            disconnected: []
        };
    }

    connect(token) {
        this.manualClose = false;
        return new Promise((resolve, reject) => {
            // Encode the token so special characters cannot corrupt the query string
            this.socket = new WebSocket(
                `wss://api.example.com/ws/llm?token=${encodeURIComponent(token)}`);

            this.socket.onopen = () => {
                console.log('WebSocket connected');
                this.reconnectAttempts = 0;
                this.listeners.connected.forEach(cb => cb());
                resolve();
            };

            this.socket.onmessage = (event) => {
                let response;
                try {
                    response = JSON.parse(event.data);
                } catch (e) {
                    console.error('Ignoring non-JSON message', e);
                    return;
                }
                switch (response.type) {
                    case 'chunk':
                        this.listeners.chunk.forEach(cb => cb(response.data));
                        break;
                    case 'done':
                        this.listeners.done.forEach(cb => cb(response.conversationId));
                        break;
                    case 'error':
                        this.listeners.error.forEach(cb => cb(response.data));
                        break;
                }
            };

            this.socket.onerror = (error) => {
                console.error('WebSocket error', error);
                reject(error);
            };

            this.socket.onclose = (event) => {
                console.log('WebSocket closed', event.code, event.reason);
                this.listeners.disconnected.forEach(cb => cb(event));
                if (!this.manualClose) {
                    this.attemptReconnect(token);
                }
            };
        });
    }

    send(data) {
        if (this.socket && this.socket.readyState === WebSocket.OPEN) {
            this.socket.send(JSON.stringify(data));
        }
    }

    sendChat(content, conversationId = null) {
        this.send({
            type: 'chat',
            content: content,
            conversationId: conversationId
        });
    }

    ping() {
        this.send({ type: 'ping' });
    }

    clear(conversationId) {
        this.send({
            type: 'clear',
            conversationId: conversationId
        });
    }

    on(event, callback) {
        if (this.listeners[event]) {
            this.listeners[event].push(callback);
        }
    }

    removeListener(event, callback) {
        if (this.listeners[event]) {
            this.listeners[event] = this.listeners[event].filter(cb => cb !== callback);
        }
    }

    // Reconnects with exponential backoff: 2s, 4s, 8s, ... capped at 30s
    attemptReconnect(token) {
        if (this.reconnectAttempts >= this.maxReconnectAttempts) {
            console.log('Max reconnection attempts reached');
            return;
        }

        this.reconnectAttempts++;
        const delay = Math.min(1000 * Math.pow(2, this.reconnectAttempts), 30000);
        console.log(`Reconnecting in ${delay}ms...`);

        setTimeout(() => {
            this.connect(token).catch(() => {});
        }, delay);
    }

    disconnect() {
        // Flag the close as deliberate instead of zeroing maxReconnectAttempts,
        // which would permanently disable reconnects for later connections
        this.manualClose = true;
        if (this.socket) {
            this.socket.close();
        }
    }
}

export const chatService = new ChatWebSocketService();
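
The delay schedule produced by `attemptReconnect` can be worked out in isolation. The sketch below reproduces the same formula as a pure function and prints the waits for all ten attempts. `reconnectDelay` and `reconnectDelayWithJitter` are hypothetical helper names, and the jitter variant is an optional addition that is not part of the service above.

```javascript
// Delay in ms before reconnect attempt `attempt` (1-based), using the same
// formula as attemptReconnect(): doubling each time, capped at capMs.
function reconnectDelay(attempt, baseMs = 1000, capMs = 30000) {
    return Math.min(baseMs * Math.pow(2, attempt), capMs);
}

// Hypothetical variant ("full jitter"): a random delay in [0, reconnectDelay).
// `random` is injectable so the function can be tested deterministically.
function reconnectDelayWithJitter(attempt, random = Math.random) {
    return Math.floor(random() * reconnectDelay(attempt));
}

const schedule = [];
for (let attempt = 1; attempt <= 10; attempt++) {
    schedule.push(reconnectDelay(attempt));
}
console.log(schedule.join(', '));
// 2000, 4000, 8000, 16000, 30000, 30000, 30000, 30000, 30000, 30000

const totalMs = schedule.reduce((sum, ms) => sum + ms, 0);
console.log(`total wait: ${totalMs / 1000}s`);
// total wait: 210s
```

With the defaults, the ten waits add up to 210 seconds, so a client gives up after about three and a half minutes plus connection time. Randomizing the delay is a common way to keep many clients from reconnecting in lockstep after a server restart.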



**Vue component**:

<template>
  <div class="chat-container">
    <div class="chat-header">
      <h2>LLM Chat</h2>
      <button @click="clearConversation" class="clear-btn">Clear</button>
    </div>

    <div class="messages" ref="messagesContainer">
      <div v-for="(msg, index) in messages" :key="index"
           :class="['message', msg.role]">
        <!-- Only assistant replies are rendered as Markdown; user and error text is interpolated and therefore escaped -->
        <div v-if="msg.role === 'assistant'" class="message-content"
             v-html="renderMarkdown(msg.content)"></div>
        <div v-else class="message-content">{{ msg.content }}</div>
      </div>

      <div v-if="isLoading" class="message assistant loading">
        <span class="cursor"></span>
      </div>
    </div>

    <div class="input-area">
      <textarea
        v-model="inputText"
        @keydown.enter.exact.prevent="sendMessage"
        placeholder="Type your message..."
        rows="3"
      ></textarea>
      <button @click="sendMessage" :disabled="!inputText || isLoading">
        Send
      </button>
    </div>
  </div>
</template>

<script>
import { chatService } from './services/websocket';
import { marked } from 'marked';
import hljs from 'highlight.js';

export default {
  data() {
    return {
      inputText: '',
      messages: [],
      isLoading: false,
      currentConversationId: null
    };
  },

  mounted() {
    // Register listeners before connecting so that no early event is missed
    chatService.on('chunk', this.appendChunk);
    chatService.on('done', this.onComplete);
    chatService.on('error', this.onError);

    // Connect to the WebSocket endpoint
    chatService.connect(this.getAuthToken()).catch((err) => {
      console.error('Initial WebSocket connection failed', err);
    });
  },

  beforeUnmount() {
    // chatService is a shared singleton, so detach this component's listeners
    chatService.removeListener('chunk', this.appendChunk);
    chatService.removeListener('done', this.onComplete);
    chatService.removeListener('error', this.onError);
    chatService.disconnect();
  },

  methods: {
    sendMessage() {
      if (!this.inputText.trim() || this.isLoading) return;

      const content = this.inputText.trim();
      this.messages.push({ role: 'user', content });
      this.inputText = '';
      this.scrollToBottom();

      this.isLoading = true;
      chatService.sendChat(content, this.currentConversationId);
    },

    appendChunk(token) {
      if (this.messages.length === 0 ||
          this.messages[this.messages.length - 1].role !== 'assistant') {
        this.messages.push({ role: 'assistant', content: token });
      } else {
        this.messages[this.messages.length - 1].content += token;
      }
      this.scrollToBottom();
    },

    onComplete(conversationId) {
      this.currentConversationId = conversationId;
      this.isLoading = false;
      // The stored content stays raw Markdown; the template renders it on display
    },

    onError(error) {
      this.isLoading = false;
      this.messages.push({
        role: 'error',
        content: 'Error: ' + error
      });
    },

    clearConversation() {
      this.messages = [];
      if (this.currentConversationId) {
        chatService.clear(this.currentConversationId);
      }
    },

    renderMarkdown(content) {
      // The result is injected with v-html, so consider sanitizing it
      // (for example with DOMPurify) before trusting model output.
      // Note: the `highlight` option is only honored by older marked releases;
      // newer ones move highlighting to the separate marked-highlight extension.
      return marked.parse(content, {
        highlight: (code, lang) => {
          if (lang && hljs.getLanguage(lang)) {
            return hljs.highlight(code, { language: lang }).value;
          }
          return code;
        }
      });
    },

    scrollToBottom() {
      this.$nextTick(() => {
        const container = this.$refs.messagesContainer;
        if (container) {
          container.scrollTop = container.scrollHeight;
        }
      });
    },

    getAuthToken() {
      // Read the token from localStorage or Vuex
      return localStorage.getItem('auth_token');
    }
  }
};
</script>

<style scoped>
.chat-container {
  display: flex;
  flex-direction: column;
  height: 100vh;
  max-width: 800px;
  margin: 0 auto;
  padding: 20px;
}

.messages {
  flex: 1;
  overflow-y: auto;
  padding: 20px;
  background: #f5f5f5;
  border-radius: 8px;
  margin-bottom: 20px;
}

.message {
  margin-bottom: 16px;
  padding: 12px 16px;
  border-radius: 12px;
  max-width: 80%;
}

.message.user {
  background: #007aff;
  color: white;
  margin-left: auto;
}

.message.assistant {
  background: white;
  color: #333;
}

.message.error {
  background: #ff3b30;
  color: white;
}

.input-area {
  display: flex;
  gap: 12px;
}

.input-area textarea {
  flex: 1;
  padding: 12px;
  border: 1px solid #ddd;
  border-radius: 8px;
  resize: none;
  font-size: 14px;
}

.input-area button {
  padding: 12px 24px;
  background: #007aff;
  color: white;
  border: none;
  border-radius: 8px;
  cursor: pointer;
}

.input-area button:disabled {
  background: #ccc;
  cursor: not-allowed;
}
</style>
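
The component's state handling can be summarized as a small reducer over the three frame types the server sends. This is a framework-free sketch, assuming the frame fields (`type`, `data`, `conversationId`) used by the `onmessage` handler of `ChatWebSocketService`. `applyFrame`, the state shape, and the id `c-1` are illustrative names, not part of the article's code.

```javascript
// Folds one server frame into the chat state and returns the new state.
function applyFrame(state, frame) {
    switch (frame.type) {
        case 'chunk':
            // Streamed text accumulates in the draft while the reply is in flight
            return { ...state, loading: true, draft: state.draft + frame.data };
        case 'done':
            // The finished draft becomes a reply and the conversation id is recorded
            return {
                ...state,
                conversationId: frame.conversationId,
                loading: false,
                draft: '',
                replies: [...state.replies, state.draft]
            };
        case 'error':
            // A failed stream discards the partial draft and keeps the error text
            return { ...state, loading: false, draft: '', lastError: frame.data };
        default:
            return state; // ignore unknown frame types
    }
}

const initial = { conversationId: null, loading: false, draft: '', replies: [] };

// Hypothetical frame sequence for one successful reply
const frames = [
    { type: 'chunk', data: 'Hel' },
    { type: 'chunk', data: 'lo' },
    { type: 'done', conversationId: 'c-1' }
];

const finalState = frames.reduce((state, frame) => applyFrame(state, frame), initial);
console.log(finalState.replies, finalState.conversationId);
```

Keeping this logic in a pure function makes the chunk/done/error protocol easy to unit test without a browser or a live socket.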




10.4 Deployment and Operations

**Docker deployment**:

FROM eclipse-temurin:17-jre
WORKDIR /app
COPY target/llm-chat-backend.jar app.jar
EXPOSE 8080
ENTRYPOINT ["java", "-jar", "app.jar"]




**docker-compose.yml**:

version: '3.8'
services:
  backend:
    build: ./backend
    ports:
      - "8080:8080"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - REDIS_HOST=redis
      - REDIS_PORT=6379
    depends_on:
      - redis

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
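    volumes:
      # Mounted into conf.d because the Nginx config shown in this section has no
      # top-level events/http blocks and cannot replace the main nginx.conf
      - ./nginx.conf:/etc/nginx/conf.d/default.conf:ro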
    depends_on:
      - backend




**Nginx configuration** (with WebSocket proxying):

upstream backend {
    server backend:8080;
}

map $http_upgrade $connection_upgrade {
    default upgrade;
    '' close;
}

server {
    listen 80;

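    location / {
        # Assumes a host or upstream named "frontend" serves the web UI;
        # it is not defined in this file or in the docker-compose.yml above
        proxy_pass http://frontend;
    }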

    location /ws/ {
        proxy_pass http://backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection $connection_upgrade;
        proxy_set_header Host $host;
        proxy_read_timeout 86400;
    }
}




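Summary

This article examined how to integrate LLM services over the WebSocket protocol. It covered each stage of building a high-performance long-lived connection service, from protocol fundamentals and implementation details to security policies and deployment.

As a mature full-duplex protocol, WebSocket gives LLM applications the real-time interaction they need. With Spring Boot's WebSocket support, developers can quickly build streaming responses, multi-turn conversations, and immediate feedback. Combined with a suitable frontend framework and security measures, this yields a web chat experience close to that of a native app.

Real projects still require technology choices and architectural tuning driven by specific needs. For very high concurrency, a high-performance network framework such as Netty may be worth considering; multi-tenant systems that need strict data isolation may require stronger session management and access control; and globally deployed systems must also handle time zones and languages.

We hope this article serves as a useful reference for developers building real-time interactive applications in the era of large models.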

