FastAPI + LangChain构建RAG知识库

文章信息

  • 标题:FastAPI + LangChain构建RAG知识库
  • 字数:4500字
  • 预估阅读时间:20分钟
  • 难度:⭐⭐⭐⭐☆

一、为什么RAG是AI应用的必经之路

近年来 RAG(Retrieval-Augmented Generation)已经成为企业AI应用的事实标准。原因很简单:大模型的知识有截止日期,且无法回答你公司内部的私有数据。RAG通过"检索+生成"的混合架构,让模型在最新、最私有、最专业的信息上作答。

FastAPI + LangChain是当前构建RAG系统最主流的技术组合,原因是:

  1. Python优先:LangChain本身是Python生态,两者天然契合
  2. 异步原生:FastAPI异步特性与LangChain的LCEL(LangChain Expression Language)配合流畅
  3. 可观测性强:中间结果(检索到的文档、得分、生成过程)都能通过FastAPI暴露为API
  4. 生产就绪:流式输出(SSE)、并发控制、错误处理都能用FastAPI原生能力实现

二、RAG核心原理

用户问题 → 改写/向量化 → 向量数据库检索 → 上下文组装 → LLM生成 → 返回答案

完整RAG流程分为索引阶段检索阶段

索引阶段(离线)

文档 → 加载(Loader) → 切分(Chunker) → 向量化(Embedding) → 写入向量数据库

检索阶段(在线)

用户问题 → 向量化 → 相似度检索 → Top-K文档 → 组装Prompt → LLM生成 → 回答

三、向量数据库选型

数据库 适用场景 优点 缺点
Chroma 原型/小型应用 轻量、嵌入式、一键启动 大规模生产场景下性能和扩展性有限,不支持分布式
FAISS 中型数据、单机部署 Meta开源、检索速度快、支持量化和聚类 无云原生、运维复杂
Qdrant 生产级应用 支持过滤条件、分布式、云原生API 资源占用较高
Milvus 超大规模数据 亿级向量支持、成熟社区 部署运维复杂
Pinecone 云原生全托管 免运维、自动扩容 需要付费、供应商锁定

选型建议

  • 个人项目/原型 → Chroma
  • 中型企业内部系统 → Qdrant(Docker一键部署)
  • 大型企业/超大规模 → Milvus或Pinecone

本文演示用Chroma(快速原型)和Qdrant(生产级)两种方案。

四、完整项目结构

fastapi-rag/
├── app/
│   ├── __init__.py
│   ├── main.py
│   ├── config.py
│   ├── models.py
│   ├── services/
│   │   ├── __init__.py
│   │   ├── document_loader.py    # 文档加载
│   │   ├── text_splitter.py      # 文本切分
│   │   ├── embedder.py           # 向量化
│   │   ├── vector_store.py       # 向量库管理
│   │   └── rag_chain.py         # RAG链
│   └── routes/
│       ├── __init__.py
│       ├── ingest.py             # 文档入库路由
│       └── query.py              # 查询路由
├── data/
│   └── sample.pdf                # 测试文档
├── .env
├── requirements.txt
└── pyproject.toml

五、环境安装

mkdir fastapi-rag && cd fastapi-rag
uv venv --python 3.12
source .venv/bin/activate

# 核心依赖
uv pip install \
    fastapi uvicorn \
    langchain langchain-core langchain-community \
    langchain-deepseek \
    chromadb qdrant-client \
    unstructured[pdf] python-dotenv \
    pydantic-settings \
    sse-starlette tiktoken \
    httpx aiofiles

六、配置管理

# app/config.py
from pydantic_settings import BaseSettings
from functools import lru_cache
from typing import Literal

class Settings(BaseSettings):
    # DeepSeek Embedding配置
    deepseek_api_key: str = "sk-your-key"
    deepseek_base_url: str = "https://api.deepseek.com"
    embed_model: str = "deepseek-embed"
    
    # LLM配置
    llm_model: str = "deepseek-chat"
    llm_temperature: float = 0.7
    llm_max_tokens: int = 2048
    
    # 向量库配置
    vector_store_type: Literal["chroma", "qdrant"] = "chroma"
    chroma_persist_dir: str = "./data/chroma_db"
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection: str = "rag_knowledge_base"
    qdrant_vector_size: int = 1536  # DeepSeek Embedding维度
    
    # RAG配置
    chunk_size: int = 500
    chunk_overlap: int = 50
    top_k: int = 5
    similarity_threshold: float = 0.6
    
    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

@lru_cache
def get_settings() -> Settings:
    return Settings()

七、文档加载与切分

7.1 文档加载服务

LangChain提供了丰富的文档加载器,支持PDF、Markdown、HTML、TXT、CSV等多种格式:

# app/services/document_loader.py
from langchain_community.document_loaders import (
    PyPDFLoader,
    UnstructuredMarkdownLoader,
    TextLoader,
    CSVLoader,
)
from langchain_core.documents import Document
from pathlib import Path
from typing import Optional

class DocumentLoaderService:
    """统一文档加载服务,支持多种格式"""
    
    LOADERS = {
        ".pdf": PyPDFLoader,
        ".md": UnstructuredMarkdownLoader,
        ".txt": TextLoader,
        ".csv": CSVLoader,
    }
    
    @classmethod
    def get_loader(cls, file_path: str) -> Optional[object]:
        """根据文件扩展名返回对应的加载器"""
        ext = Path(file_path).suffix.lower()
        loader_cls = cls.LOADERS.get(ext)
        if not loader_cls:
            return None
        return loader_cls(file_path)
    
    @classmethod
    def load_file(cls, file_path: str) -> list[Document]:
        """加载单个文件"""
        loader = cls.get_loader(file_path)
        if not loader:
            raise ValueError(f"Unsupported file type: {file_path}")
        
        # PyPDFLoader需要单独处理编码
        if isinstance(loader, PyPDFLoader):
            docs = loader.load_and_split()
        else:
            docs = loader.load()
        
        return docs
    
    @classmethod
    def load_directory(cls, directory: str, glob_pattern: str = "**/*") -> list[Document]:
        """批量加载目录下的所有文档"""
        all_docs = []
        for file_path in Path(directory).glob(glob_pattern):
            try:
                docs = cls.load_file(str(file_path))
                for doc in docs:
                    doc.metadata["source_file"] = str(file_path)
                all_docs.extend(docs)
            except Exception as e:
                print(f"Failed to load {file_path}: {e}")
        return all_docs

7.2 文本切分策略

切分是RAG质量最关键的一步。切得太长,上下文窗口浪费;切得太短,语义断裂。推荐使用RecursiveCharacterTextSplitter

# app/services/text_splitter.py
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import Optional

class TextSplitterService:
    """智能文本切分服务"""
    
    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        separators: Optional[list[str]] = None,
    ):
        if separators is None:
            separators = [
                "\n\n",   # 优先按段落切分
                "\n",     # 然后按句子切分
                "。",     # 中文句号
                "!",     # 中文感叹号
                "?",     # 中文问号
                " ",      # 空格
                "",      # 最后按字符切分
            ]
        
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
            length_function=self._len_with_chinese,
        )
    
    @staticmethod
    def _len_with_chinese(text: str) -> int:
        """中文字符算1个字符,英文按字符计算"""
        return len(text)
    
    def split_documents(self, documents: list[Document]) -> list[Document]:
        """切分文档列表"""
        if not documents:
            return []
        chunks = self.splitter.split_documents(documents)
        
        # 为每个chunk添加序号元数据
        for i, chunk in enumerate(chunks):
            chunk.metadata["chunk_index"] = i
            chunk.metadata["total_chunks"] = len(chunks)
        
        return chunks

踩坑记录1:中文切分使用英文默认separator会导致"我是一个好人"被切成"我是一个"+“好人”,语义割裂。务必自定义中文separator列表。

八、Embedding服务(DeepSeek)

# app/services/embedder.py
from langchain_core.embeddings import Embeddings
from langchain_deepseek import DeepSeekEmbeddings
from app.config import get_settings
from functools import lru_cache

class DeepSeekEmbedder(Embeddings):
    """DeepSeek Embedding封装"""
    
    def __init__(self, api_key: str, model: str = "deepseek-embed"):
        self._embedder = DeepSeekEmbeddings(
            deepseek_api_key=api_key,
            model=model,
        )
    
    def embed_query(self, text: str) -> list[float]:
        return self._embedder.embed_query(text)
    
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return self._embedder.embed_documents(texts)

@lru_cache
def get_embedder() -> DeepSeekEmbedder:
    settings = get_settings()
    return DeepSeekEmbedder(
        api_key=settings.deepseek_api_key,
        model=settings.embed_model,
    )

九、向量数据库管理

9.1 Chroma方案(轻量级)

# app/services/vector_store.py (Chroma部分)
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document
from langchain_chroma import Chroma
from app.services.embedder import get_embedder
from app.config import get_settings

class ChromaVectorStore:
    """Chroma向量库管理"""
    
    def __init__(self, collection_name: str = "documents"):
        settings = get_settings()
        self._embedder = get_embedder()
        self._persist_dir = settings.chroma_persist_dir
        self._collection_name = collection_name
        self._vectorstore: VectorStore | None = None
    
    def _get_or_create(self) -> VectorStore:
        if self._vectorstore is None:
            self._vectorstore = Chroma(
                collection_name=self._collection_name,
                embedding_function=self._embedder,
                persist_directory=self._persist_dir,
            )
        return self._vectorstore
    
    def add_documents(self, documents: list[Document]) -> list[str]:
        """入库文档"""
        vs = self._get_or_create()
        return vs.add_documents(documents)
    
    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter: dict | None = None,
    ) -> list[Document]:
        """相似度检索"""
        vs = self._get_or_create()
        return vs.similarity_search(query, k=k, filter=filter)
    
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5,
        filter: dict | None = None,
    ) -> list[tuple[Document, float]]:
        """带得分的相似度检索"""
        vs = self._get_or_create()
        return vs.similarity_search_with_score(query, k=k, filter=filter)
    
    def delete_collection(self) -> None:
        """清空集合"""
        vs = self._get_or_create()
        vs.delete_collection()
        self._vectorstore = None

9.2 Qdrant方案(生产级)

# app/services/vector_store.py (Qdrant部分)
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from app.config import get_settings

class QdrantVectorStore:
    """Qdrant向量库管理"""
    
    def __init__(
        self,
        url: str = "http://localhost:6333",
        collection_name: str = "rag_knowledge_base",
        vector_size: int = 1536,
    ):
        settings = get_settings()
        self._embedder = get_embedder()
        self._url = settings.qdrant_url
        self._collection = settings.qdrant_collection
        self._vector_size = settings.qdrant_vector_size
        self._client = QdrantClient(url=self._url)
        self._vectorstore: VectorStore | None = None
    
    def _ensure_collection(self) -> None:
        """确保集合存在,不存在则创建"""
        collections = self._client.get_collections().collections
        collection_names = [c.name for c in collections]
        
        if self._collection not in collection_names:
            from qdrant_client.http import models
            self._client.create_collection(
                collection_name=self._collection,
                vectors_config=models.VectorParams(
                    size=self._vector_size,
                    distance=models.Distance.COSINE,
                ),
            )
            print(f"Created Qdrant collection: {self._collection}")
    
    def _get_or_create(self) -> VectorStore:
        if self._vectorstore is None:
            self._ensure_collection()
            self._vectorstore = QdrantVectorStore(
                client=self._client,
                collection_name=self._collection,
                embedding=self._embedder,
            )
        return self._vectorstore
    
    def add_documents(self, documents: list[Document]) -> list[str]:
        vs = self._get_or_create()
        return vs.add_documents(documents)
    
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5,
        filter: dict | None = None,
    ) -> list[tuple[Document, float]]:
        vs = self._get_or_create()
        return vs.similarity_search_with_score(query, k=k, filter=filter)

十、RAG检索链

RAG 链的核心是:检索相关文档 → 格式化为上下文 → 拼入 Prompt → 调用 LLM 生成回答。先定义 Prompt 模板和数据类型:

from langchain_core.prompts import PromptTemplate
from typing import TypedDict

RAG_PROMPT_TEMPLATE = """你是一个专业的知识库助手。请基于以下检索到的参考资料回答用户问题。

**参考资料**:
{context}

**用户问题**: {question}

**回答要求**:
1. 必须基于参考资料回答,不要编造信息
2. 如果参考资料中没有相关信息,请明确告知用户"没有找到相关内容"
3. 回答要简洁、有条理,必要时可列出要点
4. 在回答末尾注明参考来源

回答:"""

RAG_PROMPT = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)


class RetrievedDocs(TypedDict):
    content: str
    source: str
    score: float
    chunk_index: int

然后实现 RAGChain 类——_build_chain 使用 LangChain LCEL(LangChain Expression Language)把检索器、Prompt、LLM 串成一条链:

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_deepseek import ChatDeepSeek
from app.config import get_settings


class RAGChain:
    def __init__(self):
        settings = get_settings()
        self._llm = ChatDeepSeek(
            model=settings.llm_model, temperature=settings.llm_temperature,
            max_tokens=settings.llm_max_tokens,
            api_key=settings.deepseek_api_key, base_url=settings.deepseek_base_url,
        )
        self._similarity_threshold = settings.similarity_threshold

    def _format_docs(self, docs: list[Document]) -> str:
        formatted = []
        for i, doc in enumerate(docs):
            source = doc.metadata.get("source_file", "unknown")
            formatted.append(f"[文档{i+1}] 来源: {source}\n{doc.page_content}")
        return "\n\n---\n\n".join(formatted)

    def _build_chain(self, retriever):
        return (
            {"context": retriever | self._format_docs, "question": RunnablePassthrough()}
            | RAG_PROMPT
            | self._llm
            | StrOutputParser()
        )

query 方法先检索文档,过滤低分结果,再调用链生成回答:

    def query(self, question: str, retriever) -> tuple[str, list[RetrievedDocs]]:
        docs_with_scores = retriever.invoke(question)
        filtered_docs = [doc for doc, score in docs_with_scores if score >= self._similarity_threshold]

        if not filtered_docs:
            return "没有找到与您问题相关的参考资料。", []

        chain = self._build_chain(retriever)
        answer = chain.invoke(question)

        retrieved = [
            {"content": doc.page_content[:100] + "...", "source": doc.metadata.get("source_file", "unknown"),
             "score": score, "chunk_index": doc.metadata.get("chunk_index", 0)}
            for doc, score in docs_with_scores
        ]
        return answer, retrieved

LCEL 链的执行顺序retriever | self._format_docs 把检索结果格式化为字符串,然后和 question 一起传入 Prompt,再传给 LLM,最后通过 StrOutputParser 提取纯文本。

十一、FastAPI路由封装

11.1 数据模型

from pydantic import BaseModel, Field
from typing import Optional


class IngestRequest(BaseModel):
    file_path: str = Field(description="要入库的文档路径")
    collection_name: Optional[str] = Field(default=None, description="集合名称")


class IngestResponse(BaseModel):
    status: str
    chunks_count: int
    document_id: str


class QueryRequest(BaseModel):
    question: str = Field(min_length=1, max_length=1000, description="用户问题")
    top_k: Optional[int] = Field(default=5, ge=1, le=20, description="检索数量")
    stream: bool = Field(default=False, description="是否启用流式输出")


class RetrievedDoc(BaseModel):
    content: str
    source: str
    score: float
    chunk_index: int


class QueryResponse(BaseModel):
    answer: str
    retrieved_docs: list[RetrievedDoc]
    total_retrieved: int

11.2 文档入库路由

入库流程:加载文档 → 切分 → 向量化 → 存入向量数据库。

from fastapi import APIRouter, HTTPException
from app.models import IngestRequest, IngestResponse
from app.services.document_loader import DocumentLoaderService
from app.services.text_splitter import TextSplitterService
from app.services.vector_store import ChromaVectorStore
from app.config import get_settings
import uuid

router = APIRouter(prefix="/ingest", tags=["文档入库"])


@router.post("/", response_model=IngestResponse)
async def ingest_document(request: IngestRequest):
    try:
        settings = get_settings()
        docs = DocumentLoaderService.load_file(request.file_path)
        if not docs:
            raise HTTPException(status_code=400, detail="文档加载失败,无内容")

        splitter = TextSplitterService(chunk_size=settings.chunk_size, chunk_overlap=settings.chunk_overlap)
        chunks = splitter.split_documents(docs)

        vector_store = ChromaVectorStore(collection_name=request.collection_name or "default")
        doc_ids = vector_store.add_documents(chunks)

        return IngestResponse(
            status="success", chunks_count=len(chunks),
            document_id=doc_ids[0] if doc_ids else str(uuid.uuid4()),
        )
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"文件不存在: {request.file_path}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"入库失败: {str(e)}")

11.3 查询路由

查询路由有两个接口:同步查询返回完整结果,流式查询按句子逐块返回。

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from app.models import QueryRequest, QueryResponse, RetrievedDoc
from app.services.vector_store import ChromaVectorStore
from app.services.rag_chain import RAGChain
import asyncio

router = APIRouter(prefix="/query", tags=["RAG查询"])


@router.post("/", response_model=QueryResponse)
async def query_knowledge_base(request: QueryRequest):
    vector_store = ChromaVectorStore()
    rag = RAGChain()
    answer, retrieved_docs = rag.query(request.question, vector_store.similarity_search_with_score)

    return QueryResponse(
        answer=answer,
        retrieved_docs=[RetrievedDoc(**doc) for doc in retrieved_docs[:request.top_k]],
        total_retrieved=len(retrieved_docs),
    )


@router.post("/stream")
async def query_stream(request: QueryRequest):
    async def event_generator():
        try:
            vector_store = ChromaVectorStore()
            rag = RAGChain()
            answer, retrieved = rag.query(request.question, vector_store.similarity_search_with_score)

            yield f"event: retrieved\ndata: {retrieved}\n\n"

            for sent in answer.split("。"):
                if sent.strip():
                    yield f"event: token\ndata: {sent}。\n\n"
                await asyncio.sleep(0.05)
            yield "event: done\ndata: \n\n"
        except Exception as e:
            yield f"event: error\ndata: {str(e)}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )

11.4 应用入口

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routes import ingest, query

app = FastAPI(title="RAG知识库API", version="1.0.0")

app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
app.include_router(ingest.router)
app.include_router(query.router)


@app.get("/health")
async def health():
    return {"status": "ok"}

启动命令:uvicorn app.main:app --reload --host 0.0.0.0 --port 8000

十二、性能优化

12.1 批量入库

# 避免单条入库,每次批量处理
BATCH_SIZE = 100

def batch_add_documents(vector_store, chunks: list[Document]):
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i : i + BATCH_SIZE]
        vector_store.add_documents(batch)
        print(f"Progress: {min(i + BATCH_SIZE, len(chunks))}/{len(chunks)}")

12.2 异步并发检索

# app/services/vector_store.py (异步版本)
async def async_similarity_search(
    queries: list[str],
    k: int = 5,
) -> list[list[tuple[Document, float]]]:
    """并发执行多个查询"""
    import asyncio
    
    async def single_query(q: str):
        return vector_store.similarity_search_with_score(q, k=k)
    
    tasks = [single_query(q) for q in queries]
    results = await asyncio.gather(*tasks)
    return results

12.3 Embedding缓存

# 使用缓存避免重复计算相同query的embedding
from functools import lru_cache

@lru_cache(maxsize=1024)
def cached_embed_query(text: str) -> tuple[str, list[float]]:
    """缓存Embedding结果,text作为cache key"""
    return text, embedder.embed_query(text)

踩坑记录2:Embedding计算是RAG延迟的主要来源。生产环境务必做缓存,特别是高频重复查询同一个问题。

12.4 Chroma批量操作

# Chroma原生支持批量add,避免逐条插入
vectorstore = Chroma(...)
vectorstore.add_documents(documents)  # 一次性批量入库,显著提升吞吐量

十三、踩坑记录总结

现象 解决方案
中文切分语义割裂 切分后的chunk语义不完整 自定义separator列表,加入中文标点
Chroma入库极慢 大量文档入库时卡死 使用批量add,避免循环单条入库
Embedding重复计算 相同query重复计算 使用LRU缓存(@lru_cache)
流式输出被nginx截断 SSE响应被buffer截断 nginx配置proxy_buffering off或去掉nginx直接代理
Qdrant连接超时 容器未启动就访问 先轮询检查Qdrant健康状态再操作
相似度分数不可信 不同模型的score范围不同 设relative_threshold而非absolute threshold
向量维度不匹配 Chroma报错dimension mismatch 确保Embedding模型输出维度与Chroma collection配置一致

十四、总结

本文完整构建了一套基于FastAPI + LangChain的RAG知识库系统,覆盖了:

  1. 文档加载:支持PDF、Markdown、TXT等多种格式
  2. 智能切分:RecursiveCharacterTextSplitter + 中文优化
  3. 向量检索:Chroma(原型)和Qdrant(生产)两种方案
  4. RAG生成链:DeepSeek + 精选Prompt + 检索结果注入
  5. API封装:同步查询 + 流式SSE双模式
  6. 性能优化:批量入库、缓存、异步并发

RAG的核心不在于技术实现,而在于文档质量切分策略。再好的检索系统也拯救不了差的文档。入库前务必清洗文档格式、去除无关噪音内容。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐