前言

前一篇实现了文档处理 Pipeline,上传的文档已经变成向量存入了 Qdrant。今天实现面向用户的问答接口——用户提问,从知识库检索相关内容,交给 LLM 生成流式回答。

这是整个产品面向用户的最终功能。

1. 问答整体流程

用户提问
    │
    ▼
1. Query Embedding(问题向量化)
    │
    ▼
2. Qdrant 向量检索(Top-10)
    │
    ▼
3. Cross-Encoder Re-rank(取 Top-3)
    │
    ▼
4. 构建 RAG Prompt
    │
    ▼
5. LLM 流式生成回答
    │
    ▼
6. 保存对话记录 + 引用来源
    │
    ▼
前端流式渲染

2. 检索服务

# backend/app/services/retriever.py
from typing import List, Dict, Any, Optional
from app.services.vector_store import vector_store
from app.services.embedding import embedding_service
import logging

logger = logging.getLogger(__name__)


class Retriever:
    """Vector retriever bound to a single knowledge base.

    Embeds a natural-language question and asks the vector store for the
    closest document chunks inside ``kb_id``.
    """

    def __init__(self, kb_id: str):
        # Knowledge base that every search of this instance is restricted to.
        self.kb_id = kb_id

    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: float = 0.3,
    ) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` chunks relevant to ``query``.

        ``score_threshold`` is passed straight through to ``vector_store.search``
        (presumably a minimum similarity — confirm against the store). The list
        returned by the store is handed back unchanged.
        """
        search_kwargs = {
            "kb_id": self.kb_id,
            "query_vector": embedding_service.encode_query(query),
            "top_k": top_k,
            "score_threshold": score_threshold,
        }
        hits = await vector_store.search(**search_kwargs)

        logger.info(f"Retrieved {len(hits)} chunks for query: {query[:50]}")
        return hits


class ReRanker:
    """Cross-Encoder re-ranker backed by ``BAAI/bge-reranker-v2-m3`` on CPU.

    The model is loaded lazily on the first ``rerank`` call, so importing this
    module stays cheap.
    """

    def __init__(self):
        self._model = None

    def _load(self):
        """Load the CrossEncoder model once; later calls are no-ops."""
        if self._model is None:
            from sentence_transformers import CrossEncoder
            self._model = CrossEncoder("BAAI/bge-reranker-v2-m3", device="cpu")
            logger.info("Loaded reranker model")

    def rerank(
        self, query: str, documents: List[Dict[str, Any]], top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """Return the ``top_k`` most relevant documents for ``query``.

        Args:
            query: the user question.
            documents: retrieved chunks; each must have a ``"text"`` key.
            top_k: number of documents to keep; ``<= 0`` returns an empty list.

        Returns:
            Shallow copies of the best documents, highest score first, each with
            an added ``"rerank_score"`` float. The caller's list and dicts are
            left untouched (the previous version sorted and annotated them in place).
        """
        if not documents or top_k <= 0:
            return []

        self._load()
        # Only the first 512 characters of each chunk are scored to bound the cost.
        pairs = [[query, doc["text"][:512]] for doc in documents]
        scores = self._model.predict(pairs)

        scored = [
            {**doc, "rerank_score": float(score)}
            for doc, score in zip(documents, scores)
        ]
        scored.sort(key=lambda d: d["rerank_score"], reverse=True)
        return scored[:top_k]


# `retriever` is the Retriever *class* itself, not an instance: callers build a
# per-knowledge-base instance with `retriever(kb_id)`.
retriever = Retriever
# One shared reranker for the process; its model is loaded on the first rerank() call.
reranker = ReRanker()

3. 问答 API(流式输出)

3.1 Schema

# backend/app/schemas/chat.py
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime


class ChatRequest(BaseModel):
    """Request body of the streaming chat endpoint."""

    # Knowledge base to retrieve from (UUID string).
    kb_id: str
    # The user's question.
    message: str
    # Conversation to continue; None starts a new conversation.
    conversation_id: Optional[str] = None


class Citation(BaseModel):
    """A retrieved source snippet cited by an assistant answer."""

    # Source document name — presumably the chunk's `source` payload field; confirm.
    source: str
    # Snippet text.
    text: str
    # Relevance score — NOTE(review): scale depends on whether the vector or rerank score is stored.
    score: float


class MessageResponse(BaseModel):
    """A stored chat message as returned by the message-history endpoint."""

    id: str
    conversation_id: str
    content: str
    # "user" or "assistant".
    role: str
    # Pydantic copies field defaults per instance, so the shared [] literal is safe.
    citations: List[Citation] = []
    created_at: datetime


class ConversationResponse(BaseModel):
    """Summary row of a conversation for the conversation list."""

    id: str
    title: str
    # Number of stored messages (user + assistant).
    message_count: int
    created_at: datetime
    # The list endpoint sorts by this field, newest first.
    updated_at: datetime

3.2 对话 Service

# backend/app/services/chat_service.py
import asyncio
import uuid
from datetime import datetime
from typing import AsyncGenerator, Optional
from uuid import UUID

from openai import AsyncOpenAI
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.conversation import Conversation, Message, Citation
from app.services.retriever import retriever, reranker


class ChatService:
    """Conversation service: persists chat history and answers questions with RAG."""

    def __init__(self):
        # One shared OpenAI-compatible async client; endpoint and key come from settings.
        self.llm_client = AsyncOpenAI(
            api_key=settings.LLM_API_KEY,
            base_url=settings.LLM_BASE_URL,
        )

        # System prompt template; {context} and {question} are filled with str.format.
        self._rag_prompt = """你是 KNow 知识库助手,请基于以下检索到的资料回答用户的问题。

## 规则
1. 只基于检索资料回答,不要编造信息
2. 如果资料不足以回答问题,请直接说"资料中没有相关信息"
3. 在回答中引用资料来源,用 [1]、[2] 标注
4. 回答要简洁、准确、有条理
5. 使用中文回答

## 检索资料

{context}

## 用户问题

{question}
"""

    @staticmethod
    def _format_sse(payload: str) -> str:
        """Encode ``payload`` as one Server-Sent Event.

        Each line of the payload gets its own ``data:`` field. Per the SSE spec,
        the client rejoins these lines with "\\n". A single ``data:`` line
        containing raw newlines would break the event framing, and the text after
        the first newline would be lost.
        """
        return "".join(f"data: {line}\n" for line in payload.split("\n")) + "\n"

    @staticmethod
    def _build_context(ranked) -> str:
        """Render ranked chunks as numbered ``[i] 来源:...`` sections for the prompt."""
        return "\n\n".join(
            f"[{i + 1}] 来源:{r['source']}\n{r['text']}"
            for i, r in enumerate(ranked)
        )

    async def create_conversation(
        self, db: AsyncSession, user_id: str, kb_id: str, title: str = "新对话"
    ) -> Conversation:
        """Create, commit and return a new conversation owned by ``user_id``.

        Raises:
            ValueError: ``user_id`` or ``kb_id`` is not a valid UUID string.
        """
        conv = Conversation(
            user_id=UUID(user_id),
            knowledge_base_id=UUID(kb_id),
            title=title,
        )
        db.add(conv)
        await db.commit()
        # Reload the row so its attributes (e.g. id) are usable after the commit.
        await db.refresh(conv)
        return conv

    async def get_conversation(
        self, db: AsyncSession, conv_id: str, user_id: str
    ) -> Conversation:
        """Return conversation ``conv_id`` if it belongs to ``user_id``.

        Raises:
            ValueError: the conversation does not exist, belongs to another user,
                or an id is not a valid UUID string.
        """
        result = await db.execute(
            select(Conversation).where(
                Conversation.id == UUID(conv_id),
                Conversation.user_id == UUID(user_id),
            )
        )
        conv = result.scalar_one_or_none()
        if not conv:
            raise ValueError("对话不存在")
        return conv

    async def chat_stream(
        self,
        db: AsyncSession,
        user_id: str,
        kb_id: str,
        message: str,
        conv_id: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Answer ``message`` with RAG and stream the answer as SSE strings.

        Yields SSE-framed chunks of the answer and finally ``"data: [DONE]\\n\\n"``.
        The user message is stored before generation; the assistant message, its
        citations and the conversation's counters are stored after it.

        Raises:
            ValueError: ``conv_id`` is given but not found for this user.
        """
        # 1. Get or create the conversation.
        if conv_id:
            conv = await self.get_conversation(db, conv_id, user_id)
        else:
            conv = await self.create_conversation(db, user_id, kb_id)

        # 2. Persist the user message first so the question survives a failed generation.
        db.add(Message(conversation_id=conv.id, role="user", content=message))
        await db.commit()
        # With SQLAlchemy's default expire_on_commit=True the commit expires `conv`;
        # reading its attributes afterwards would need implicit IO, which
        # AsyncSession does not allow. Reload it explicitly.
        await db.refresh(conv)
        is_first_message = not conv.message_count  # count may be None on a fresh row

        # 3. Vector retrieval.
        retrieved = await retriever(kb_id).retrieve(message, top_k=10)

        # 4. Re-rank. This is synchronous CPU model inference (and loads the model on
        # first use), so run it in a worker thread to keep the event loop responsive.
        ranked = await asyncio.to_thread(reranker.rerank, message, retrieved, 3)

        # 5. Build the prompt.
        system_prompt = self._rag_prompt.format(
            context=self._build_context(ranked), question=message
        )

        # 6. Recent history (last 6 rounds = 12 messages, including the message just saved).
        history = await db.execute(
            select(Message)
            .where(Message.conversation_id == conv.id)
            .order_by(Message.created_at.desc())
            .limit(12)
        )
        llm_messages = [{"role": "system", "content": system_prompt}]
        llm_messages.extend(
            {"role": m.role, "content": m.content}
            for m in reversed(history.scalars().all())
        )

        # 7. Stream the LLM answer to the client while collecting it.
        stream = await self.llm_client.chat.completions.create(
            model=settings.LLM_MODEL,
            messages=llm_messages,
            temperature=0.3,
            max_tokens=2048,
            stream=True,
        )
        parts = []
        async for chunk in stream:
            delta = chunk.choices[0].delta if chunk.choices else None
            if delta and delta.content:
                parts.append(delta.content)
                yield self._format_sse(delta.content)
        full_content = "".join(parts)

        # 8. Persist the answer, its citations and the counters in one transaction.
        assistant_msg = Message(
            conversation_id=conv.id,
            role="assistant",
            content=full_content,
            tokens_used=len(full_content) // 2,  # rough estimate, not real token usage
        )
        db.add(assistant_msg)
        # flush (not commit) assigns the primary key without expiring `conv`.
        await db.flush()

        for r in ranked:
            db.add(
                Citation(
                    message_id=assistant_msg.id,
                    # Use the chunk's document id if the vector payload carries one.
                    # NOTE(review): confirm the payload key and that "" is acceptable
                    # for the column when it is missing.
                    document_id=r.get("document_id", ""),
                    chunk_text=r["text"][:200],
                    score=r.get("rerank_score", r.get("score", 0)),
                )
            )

        # The title is set only here, at the end. An UPDATE flushed before the stream
        # would keep the conversation row locked for the whole generation.
        if is_first_message:
            conv.title = message[:50] + ("..." if len(message) > 50 else "")
        conv.message_count = (conv.message_count or 0) + 2
        await db.commit()

        yield "data: [DONE]\n\n"

3.3 对话路由

# backend/app/routers/chat.py
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List

from app.database import get_db
from app.services.auth import require_auth
from app.models.user import User
from app.schemas.chat import (
    ChatRequest, MessageResponse,
    ConversationResponse, Citation,
)
from app.services.chat_service import ChatService

router = APIRouter()
# One shared service instance; it holds a single AsyncOpenAI client for all requests.
chat_service = ChatService()


@router.post("/chat/stream")
async def chat_stream(
    body: ChatRequest,
    user: User = Depends(require_auth),
    db: AsyncSession = Depends(get_db),
):
    """Stream a knowledge-base answer as Server-Sent Events.

    The conversation is resolved (or created) *before* the stream starts:
    - an unknown or foreign ``conversation_id`` gets a real 404, instead of an
      exception inside an already-started 200 stream;
    - the conversation id can be returned up front in the ``X-Conversation-Id``
      header, so the client can continue the same conversation.

    Raises:
        HTTPException(404): the conversation does not exist or is not the user's.
        HTTPException(400): ``kb_id`` is not a valid id.
    """
    # NOTE(review): kb_id is not checked against the user here — confirm that
    # knowledge-base access control happens elsewhere.
    user_id = str(user.id)
    if body.conversation_id:
        try:
            conv = await chat_service.get_conversation(db, body.conversation_id, user_id)
        except ValueError as e:
            raise HTTPException(status_code=404, detail=str(e))
    else:
        try:
            conv = await chat_service.create_conversation(db, user_id, body.kb_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid kb_id")
    conv_id = str(conv.id)

    # NOTE(review): `db` is used by the generator after this function returns;
    # confirm get_db's cleanup does not close the session before the stream ends
    # (FastAPI's handling of yield-dependencies with StreamingResponse varies by version).
    return StreamingResponse(
        chat_service.chat_stream(
            db=db,
            user_id=user_id,
            kb_id=body.kb_id,
            message=body.message,
            conv_id=conv_id,
        ),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # stop nginx from buffering the stream
            "X-Conversation-Id": conv_id,
        },
    )


@router.get("/conversations", response_model=List[ConversationResponse])
async def list_conversations(
    kb_id: str = Query(""),
    user: User = Depends(require_auth),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's 50 most recently updated conversations.

    Args:
        kb_id: optional knowledge-base filter; an empty string means "all".

    Raises:
        HTTPException(400): ``kb_id`` is not a valid UUID.
    """
    # UUID was previously used without being imported, so any kb_id filter
    # failed with NameError (HTTP 500).
    from uuid import UUID

    from sqlalchemy import select
    from app.models.conversation import Conversation

    query = select(Conversation).where(Conversation.user_id == user.id)
    if kb_id:
        try:
            kb_uuid = UUID(kb_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid kb_id") from None
        query = query.where(Conversation.knowledge_base_id == kb_uuid)
    query = query.order_by(Conversation.updated_at.desc()).limit(50)

    result = await db.execute(query)
    convs = result.scalars().all()

    return [
        ConversationResponse(
            id=str(c.id),
            title=c.title,
            message_count=c.message_count or 0,
            created_at=c.created_at,
            updated_at=c.updated_at,
        )
        for c in convs
    ]


@router.get("/conversations/{conv_id}/messages", response_model=List[MessageResponse])
async def get_messages(
    conv_id: str,
    user: User = Depends(require_auth),
    db: AsyncSession = Depends(get_db),
):
    """Return all messages of one of the user's conversations, oldest first.

    Raises:
        HTTPException(404): the conversation does not exist or is not the user's.
    """
    # `select` was previously not in scope here (NameError on every call).
    from sqlalchemy import select
    from app.models.conversation import Message

    try:
        conv = await chat_service.get_conversation(db, conv_id, str(user.id))
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conv.id)
        .order_by(Message.created_at)
    )
    messages = result.scalars().all()

    return [
        MessageResponse(
            id=str(m.id),
            conversation_id=str(conv.id),
            content=m.content,
            role=m.role,
            citations=[],  # simplified: stored citations are not loaded yet
            created_at=m.created_at,
        )
        for m in messages
    ]


@router.delete("/conversations/{conv_id}", status_code=204)
async def delete_conversation(
    conv_id: str,
    user: User = Depends(require_auth),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the user's conversations (204 No Content on success).

    Raises:
        HTTPException(404): the conversation does not exist or is not the user's.
    """
    try:
        conv = await chat_service.get_conversation(db, conv_id, str(user.id))
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    # NOTE(review): this relies on the model/DB cascading the delete to messages
    # and citations — confirm the relationship/foreign-key cascade settings.
    await db.delete(conv)
    await db.commit()

3.4 注册路由

# backend/app/main.py(更新)
from app.routers import chat

# Mount the chat endpoints under /api (e.g. POST /api/chat/stream, GET /api/conversations).
app.include_router(chat.router, prefix="/api", tags=["Chat"])

4. 前端对话页面

4.1 API 层

// frontend/src/api/chat.ts
import api from "@/lib/api";

/** A source snippet attached to an assistant answer. */
export interface Citation {
  source: string;
  text: string;
  score: number;
}

/** A stored message as returned by GET /conversations/{id}/messages. */
export interface Message {
  id: string;
  conversation_id: string;
  content: string;
  role: string; // "user" | "assistant"
  citations: Citation[];
  created_at: string; // ISO timestamp
}

/** Conversation summary as returned by GET /conversations. */
export interface Conversation {
  id: string;
  title: string;
  message_count: number;
  created_at: string;
  /** ISO timestamp of the last update; the list is sorted by it, newest first. */
  updated_at: string;
}

/** Fetch the user's conversations, optionally limited to one knowledge base. */
export async function listConversations(kbId?: string): Promise<Conversation[]> {
  const response = await api.get("/conversations", {
    params: kbId ? { kb_id: kbId } : {},
  });
  return response.data;
}

/** Fetch every message of a conversation, oldest first. */
export async function getMessages(convId: string): Promise<Message[]> {
  const response = await api.get(`/conversations/${convId}/messages`);
  return response.data;
}

/** Delete a conversation; the server answers 204 No Content. */
export async function deleteConversation(convId: string): Promise<void> {
  await api.delete(`/conversations/${convId}`);
}

4.2 流式对话 Hook

// frontend/src/hooks/useChat.ts
import { useState, useRef, useCallback } from "react";
import api from "@/lib/api";

/** A message as held in the chat UI's local state. */
interface ChatMessage {
  id: string; // client-generated, timestamp-based; used as the React key
  role: "user" | "assistant";
  content: string;
  citations?: any[];
}

export function useChat(kbId: string) {
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const [isLoading, setIsLoading] = useState(false);
  const [convId, setConvId] = useState<string | undefined>();
  const abortRef = useRef<AbortController | null>(null);

  const sendMessage = useCallback(async (content: string) => {
    // Optimistically show the user's message.
    const userMsg: ChatMessage = {
      id: Date.now().toString(),
      role: "user",
      content,
    };
    setMessages((prev) => [...prev, userMsg]);
    setIsLoading(true);

    // Empty assistant bubble that is filled in as the stream arrives.
    const assistantId = (Date.now() + 1).toString();
    setMessages((prev) => [...prev, { id: assistantId, role: "assistant", content: "" }]);

    const setAssistantContent = (text: string) =>
      setMessages((prev) =>
        prev.map((m) => (m.id === assistantId ? { ...m, content: text } : m))
      );

    try {
      abortRef.current = new AbortController();
      const token = localStorage.getItem("token");

      const response = await fetch("/api/chat/stream", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${token}`,
        },
        body: JSON.stringify({
          kb_id: kbId,
          message: content,
          conversation_id: convId,
        }),
        signal: abortRef.current.signal,
      });

      // Error responses (401/404/...) are JSON, not an event stream.
      if (!response.ok || !response.body) {
        throw new Error(`Chat request failed: HTTP ${response.status}`);
      }

      // Remember the conversation so follow-up questions continue it instead of
      // starting a new one (no-op if the server does not send the header).
      const serverConvId = response.headers.get("X-Conversation-Id");
      if (serverConvId) setConvId(serverConvId);

      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = "";
      let fullContent = "";

      // Handle one SSE event (the text between two blank lines). Per the SSE
      // spec, the `data:` lines of one event are joined with "\n", which keeps
      // newlines inside the streamed markdown. Returns true on the [DONE] event.
      const handleEvent = (rawEvent: string): boolean => {
        const dataLines = rawEvent
          .split("\n")
          .filter((line) => line.startsWith("data:"))
          .map((line) => line.slice(5).replace(/^ /, ""));
        if (dataLines.length === 0) return false;
        const data = dataLines.join("\n");
        if (data === "[DONE]") return true;
        fullContent += data;
        setAssistantContent(fullContent);
        return false;
      };

      let finished = false;
      while (!finished) {
        const { done, value } = await reader.read();
        if (done) break;
        buffer += decoder.decode(value, { stream: true });

        // Consume every complete event; keep the partial tail in `buffer`.
        let boundary = buffer.indexOf("\n\n");
        while (boundary !== -1 && !finished) {
          finished = handleEvent(buffer.slice(0, boundary));
          buffer = buffer.slice(boundary + 2);
          boundary = buffer.indexOf("\n\n");
        }
      }
    } catch (err: any) {
      // A user-initiated stop (AbortError) keeps whatever was already streamed.
      if (err.name !== "AbortError") {
        setMessages((prev) =>
          prev.map((m) =>
            m.id === assistantId
              ? { ...m, content: "请求失败,请重试" }
              : m
          )
        );
      }
    } finally {
      setIsLoading(false);
    }
  }, [kbId, convId]);

  const stopGeneration = useCallback(() => {
    abortRef.current?.abort();
    setIsLoading(false);
  }, []);

  return {
    messages,
    isLoading,
    sendMessage,
    stopGeneration,
    setConvId,
  };
}

4.3 对话页面

// frontend/src/pages/Chat.tsx
import { useState, useRef, useEffect } from "react";
import { useSearchParams } from "react-router-dom";
import { useChat } from "@/hooks/useChat";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import ReactMarkdown from "react-markdown";

export default function Chat() {
  const [searchParams] = useSearchParams();
  const kbId = searchParams.get("kb") || "";
  const { messages, isLoading, sendMessage, stopGeneration } = useChat(kbId);
  const [input, setInput] = useState("");
  const bottomRef = useRef<HTMLDivElement>(null);
  const inputRef = useRef<HTMLInputElement>(null);

  // Keep the newest message in view while the answer streams in.
  useEffect(() => {
    bottomRef.current?.scrollIntoView({ behavior: "smooth" });
  }, [messages]);

  // The input is disabled while loading, which drops focus; restore it afterwards.
  useEffect(() => {
    if (!isLoading) inputRef.current?.focus();
  }, [isLoading]);

  const handleSend = () => {
    if (!input.trim() || isLoading || !kbId) return;
    sendMessage(input.trim());
    setInput("");
  };

  return (
    <div className="flex flex-col h-[calc(100vh-4rem)] max-w-4xl mx-auto">
      {!kbId ? (
        <div className="flex-1 flex items-center justify-center">
          <div className="text-center">
            <div className="text-6xl mb-4">💬</div>
            <h2 className="text-xl font-semibold text-gray-600">选择一个知识库开始问答</h2>
            <p className="text-gray-400 mt-2">
              从仪表盘进入知识库详情,点击"开始问答"
            </p>
          </div>
        </div>
      ) : messages.length === 0 ? (
        <div className="flex-1 flex items-center justify-center">
          <div className="text-center max-w-md">
            <h2 className="text-lg font-semibold">开始提问</h2>
            <p className="text-sm text-gray-400 mt-2">
              你可以问关于知识库中文档的任何问题
            </p>
            <div className="mt-4 space-y-2">
              {[
                "这个项目的主要功能是什么?",
                "文档中对系统架构的描述是怎样的?",
                "有哪些关键的技术决策?",
              ].map((q) => (
                <button
                  key={q}
                  className="block w-full text-left px-4 py-2 rounded-lg border hover:bg-gray-50 text-sm"
                  onClick={() => sendMessage(q)}
                >
                  {q}
                </button>
              ))}
            </div>
          </div>
        </div>
      ) : (
        <div className="flex-1 overflow-y-auto px-4 py-4 space-y-4">
          {messages.map((msg) => (
            <div
              key={msg.id}
              className={`flex ${msg.role === "user" ? "justify-end" : "justify-start"}`}
            >
              <div
                className={`max-w-[80%] rounded-xl px-4 py-3 ${
                  msg.role === "user"
                    ? "bg-blue-600 text-white"
                    : "bg-gray-100 text-gray-800"
                }`}
              >
                {msg.role === "assistant" ? (
                  <div className="prose prose-sm max-w-none">
                    <ReactMarkdown>{msg.content || "..."}</ReactMarkdown>
                  </div>
                ) : (
                  <p className="text-sm">{msg.content}</p>
                )}
              </div>
            </div>
          ))}
          <div ref={bottomRef} />
        </div>
      )}

      <div className="border-t p-4">
        <div className="flex gap-2 max-w-4xl mx-auto">
          <Input
            ref={inputRef}
            value={input}
            onChange={(e) => setInput(e.target.value)}
            onKeyDown={(e) => {
              // Enter that confirms an IME candidate (e.g. Chinese pinyin input)
              // must not send the half-typed question.
              if (e.key !== "Enter" || e.shiftKey || e.nativeEvent.isComposing) return;
              e.preventDefault();
              handleSend();
            }}
            placeholder={kbId ? "输入问题..." : "请先选择知识库"}
            disabled={!kbId || isLoading}
          />
          {isLoading ? (
            <Button variant="outline" onClick={stopGeneration}>
              停止
            </Button>
          ) : (
            <Button onClick={handleSend} disabled={!kbId || !input.trim()}>
              发送
            </Button>
          )}
        </div>
      </div>
    </div>
  );
}

5. 验证

# 1. Create a conversation and ask a question.
#    -N disables curl's output buffering so SSE chunks appear as they arrive.
curl -N -X POST http://localhost:8000/api/chat/stream \
  -H "Authorization: Bearer <token>" \
  -H "Content-Type: application/json" \
  -d '{"kb_id":"<kb_id>","message":"这个项目的主要功能是什么?"}'

# Response (SSE stream)
# data: 这个项目主要...
# data: 功能包括以下...
# data: [DONE]

# 2. List conversations
curl http://localhost:8000/api/conversations \
  -H "Authorization: Bearer <token>"

# 3. Fetch a conversation's message history
curl http://localhost:8000/api/conversations/<conv_id>/messages \
  -H "Authorization: Bearer <token>"

总结

今天完成了产品最核心的功能——基于知识库的智能问答:

| 组件 | 说明 |
| --- | --- |
| 检索服务 | Qdrant 向量检索 + Re-rank 重排序 |
| 问答 API | 流式 SSE 输出,支持引用来源 |
| 对话管理 | 创建/列表/历史/删除 |
| 前端 Hook | 流式消息解析 + 加载状态 + 中止 |
| 对话界面 | Markdown 渲染 + 预设问题 + 自动滚动 |

现在用户上传文档后,可以直接向知识库提问了。

下一篇我们将开始前端开发——搭建完整的用户界面。


本文是 《AI 全栈开发实战——做一个真正的产品》 系列的第 6 篇。
系列目录:

  1. ✅ 产品定义
  2. ✅ 项目初始化
  3. ✅ 用户系统
  4. ✅ 知识库与文档管理
  5. ✅ 文档处理 Pipeline
  6. ✅ 向量检索与 RAG 问答 ← 你在这里
  7. 📝 前端开发(一)

本文由 Zyentor(智元界) 原创发布


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐