前言

前一篇我们实现了文档上传管理,但上传后的文档还是"死"的——只是存了个文件。今天我们要让文档活起来:解析内容、切分成片段、向量化、存入向量数据库。

这是整个产品的核心引擎——文档处理 Pipeline。

1. Pipeline 整体流程

文档上传 → 触发 Celery 任务
                ↓
          1. 解析文档(PDF→文本)
                ↓
          2. 文本切分(Chunk)
                ↓
          3. Embedding 向量化
                ↓
          4. 存入 Qdrant
                ↓
          5. 更新文档状态 → ready

2. Celery 任务配置

# backend/app/tasks/__init__.py
from celery import Celery
from app.config import settings

celery_app = Celery(
    "know",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
    task_serializer="json",
    accept_content=["json"],
)

celery_app.conf.task_routes = {
    "process_document": {"queue": "documents"},
}
# backend/app/tasks/document_tasks.py
from app.tasks import celery_app
from app.services.document_processor import DocumentProcessor


@celery_app.task(bind=True, max_retries=3, default_retry_delay=60)
def process_document(self, document_id: str, kb_id: str):
    """处理文档:解析 → 切分 → Embedding → 存入 Qdrant。"""
    from app.database import SyncSession
    from app.models.document import Document

    db = SyncSession()
    try:
        doc = db.query(Document).filter(Document.id == document_id).first()
        if not doc:
            return {"error": "Document not found"}

        # 更新状态为处理中
        doc.status = "processing"
        db.commit()

        # 执行处理
        processor = DocumentProcessor()
        result = processor.process(
            document_id=str(doc.id),
            kb_id=kb_id,
            file_path=doc.file_path,
            filename=doc.filename,
            file_type=doc.file_type,
        )

        # 更新状态
        doc.status = "ready"
        doc.chunk_count = result.get("chunk_count", 0)
        db.commit()

        return result

    except Exception as e:
        doc.status = "failed"
        doc.error_message = str(e)
        db.commit()
        raise self.retry(exc=e)

    finally:
        db.close()

3. 文档解析

# backend/app/services/document_processor.py
import io
import logging
from pathlib import Path
from typing import List, Dict, Any

from app.services.storage import file_storage

logger = logging.getLogger(__name__)


class DocumentParser:
    """文档解析器——支持多种格式。"""

    @staticmethod
    def parse(file_path: str, filename: str, file_type: str) -> str:
        """从 MinIO 下载文件并解析为纯文本。"""
        # 从 MinIO 读取文件
        response = file_storage.get_object(file_path)
        file_data = response.read()
        response.close()

        if file_type == "pdf":
            return DocumentParser._parse_pdf(file_data)
        elif file_type == "txt":
            return file_data.decode("utf-8", errors="ignore")
        elif file_type == "md":
            return file_data.decode("utf-8", errors="ignore")
        elif file_type == "docx":
            return DocumentParser._parse_docx(file_data)
        else:
            raise ValueError(f"不支持的文件类型: {file_type}")

    @staticmethod
    def _parse_pdf(file_data: bytes) -> str:
        """解析 PDF。"""
        try:
            import fitz  # PyMuPDF
        except ImportError:
            raise ImportError("请安装 PyMuPDF: pip install pymupdf")

        doc = fitz.open(stream=file_data, filetype="pdf")
        text = []
        for page_num, page in enumerate(doc):
            page_text = page.get_text()
            if page_text.strip():
                text.append(f"## 第 {page_num + 1} 页\n\n{page_text}")
        doc.close()

        result = "\n\n".join(text)
        if not result.strip():
            # PDF 可能是扫描件,没有文本层
            logger.warning("PDF 文本为空,可能是扫描件")
        return result

    @staticmethod
    def _parse_docx(file_data: bytes) -> str:
        """解析 DOCX。"""
        try:
            from docx import Document as DocxDocument
        except ImportError:
            raise ImportError("请安装 python-docx: pip install python-docx")

        doc = DocxDocument(io.BytesIO(file_data))
        paragraphs = []
        for para in doc.paragraphs:
            if para.text.strip():
                paragraphs.append(para.text)
        return "\n\n".join(paragraphs)

4. 文本切分

# backend/app/services/text_splitter.py
from typing import List, Optional
import re


class RecursiveTextSplitter:
    """递归字符文本切分器。"""

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 128,
        separators: Optional[List[str]] = None,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n## ", "\n### ", "\n\n", "\n", "。", "!", "?", ",", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """切分文本为 chunks。"""
        return self._split(text, 0)

    def _split(self, text: str, level: int) -> List[str]:
        if level >= len(self.separators):
            return self._split_by_chars(text)

        separator = self.separators[level]

        if not separator:
            return self._split_by_chars(text)

        segments = text.split(separator)
        chunks = []
        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue
            if len(seg) <= self.chunk_size:
                chunks.append(seg)
            else:
                chunks.extend(self._split(seg, level + 1))

        return self._merge(chunks)

    def _merge(self, chunks: List[str]) -> List[str]:
        """合并过小的 chunks 并添加 overlap。"""
        result = []
        for i, chunk in enumerate(chunks):
            if not chunk:
                continue

            # 从上一个 chunk 尾部取 overlap 内容
            if i > 0 and self.chunk_overlap > 0:
                prev = result[-1]
                overlap = prev[-self.chunk_overlap:] if len(prev) > self.chunk_overlap else prev
                chunk = overlap + chunk

            # 如果结果为空或当前 chunk 和上一个合并后不超过 chunk_size
            if not result or len(result[-1]) + len(chunk) < self.chunk_size * 1.5:
                if result:
                    result[-1] = result[-1] + "\n" + chunk
                else:
                    result.append(chunk)
            else:
                result.append(chunk)

        return result

    def _split_by_chars(self, text: str) -> List[str]:
        """按字符数切分(最后的兜底方案)。"""
        chunks = []
        start = 0
        while start < len(text):
            end = min(start + self.chunk_size, len(text))
            if end < len(text):
                cut = text.rfind("\n", start, end)
                if cut <= start:
                    cut = text.rfind("。", start, end)
                    if cut <= start:
                        cut = end
                    else:
                        cut += 1
                else:
                    cut += 1
            else:
                cut = end
            chunks.append(text[start:cut].strip())
            start = max(cut - self.chunk_overlap, start + 1)
        return [c for c in chunks if c]

    def split_with_metadata(self, text: str, source: str) -> List[Dict[str, Any]]:
        """切分并生成带元数据的 chunks。"""
        texts = self.split_text(text)
        chunks = []
        for i, content in enumerate(texts):
            chunks.append({
                "id": f"{source}#chunk{i}",
                "text": content,
                "metadata": {
                    "source": source,
                    "chunk_index": i,
                    "total_chunks": len(texts),
                },
            })
        return chunks

5. Embedding 服务

# backend/app/services/embedding.py
import numpy as np
from typing import List, Optional
from app.config import settings


class EmbeddingService:
    """Embedding 服务——将文本转为向量。"""

    def __init__(self):
        self._model = None
        self._dimension = settings.EMBEDDING_DIMENSION

    def _load_model(self):
        """延迟加载模型。"""
        if self._model is not None:
            return
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer(
            settings.EMBEDDING_MODEL,
            device="cpu",
        )
        logger.info(f"Loaded embedding model: {settings.EMBEDDING_MODEL}")

    def encode(self, texts: List[str]) -> List[List[float]]:
        """批量编码文本为向量。"""
        if not texts:
            return []

        self._load_model()
        embeddings = self._model.encode(
            texts,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        return embeddings.tolist()

    def encode_query(self, text: str) -> List[float]:
        """编码查询文本。"""
        return self.encode([text])[0]

    @property
    def dimension(self) -> int:
        return self._dimension


embedding_service = EmbeddingService()

6. Qdrant 向量存储

# backend/app/services/vector_store.py
from qdrant_client import QdrantClient
from qdrant_client.http import models
from typing import List, Dict, Any, Optional
from app.config import settings
import logging

logger = logging.getLogger(__name__)


class VectorStore:
    """Qdrant 向量存储服务。"""

    def __init__(self):
        self.client = QdrantClient(
            host=settings.QDRANT_HOST,
            port=settings.QDRANT_PORT,
        )

    def _collection_name(self, kb_id: str) -> str:
        return f"kb_{kb_id.replace('-', '_')}"

    async def ensure_collection(self, kb_id: str, dimension: int):
        """确保知识库的 collection 存在。"""
        name = self._collection_name(kb_id)
        collections = self.client.get_collections().collections
        exists = any(c.name == name for c in collections)

        if not exists:
            self.client.create_collection(
                collection_name=name,
                vectors_config=models.VectorParams(
                    size=dimension,
                    distance=models.Distance.COSINE,
                ),
            )
            logger.info(f"Created collection: {name}")

    async def upsert_chunks(
        self,
        kb_id: str,
        chunks: List[Dict[str, Any]],
        vectors: List[List[float]],
    ):
        """插入或更新文档片段。"""
        name = self._collection_name(kb_id)

        points = []
        for i, (chunk, vector) in enumerate(zip(chunks, vectors)):
            points.append(models.PointStruct(
                id=hash(chunk["id"]) % (2**63),
                vector=vector,
                payload={
                    "chunk_id": chunk["id"],
                    "text": chunk["text"][:2000],  # 限制 payload 大小
                    "source": chunk["metadata"]["source"],
                    "chunk_index": chunk["metadata"]["chunk_index"],
                },
            ))

        self.client.upsert(
            collection_name=name,
            points=points,
        )

    async def search(
        self,
        kb_id: str,
        query_vector: List[float],
        top_k: int = 10,
        score_threshold: float = 0.3,
    ) -> List[Dict]:
        """检索最相似的文档片段。"""
        name = self._collection_name(kb_id)

        try:
            results = self.client.search(
                collection_name=name,
                query_vector=query_vector,
                limit=top_k,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logger.warning(f"Search failed (collection may not exist): {e}")
            return []

        return [
            {
                "chunk_id": r.payload["chunk_id"],
                "text": r.payload["text"],
                "source": r.payload["source"],
                "score": r.score,
            }
            for r in results
        ]

    async def delete_kb_vectors(self, kb_id: str):
        """删除整个知识库的向量。"""
        name = self._collection_name(kb_id)
        try:
            self.client.delete_collection(name)
        except Exception:
            pass

    async def delete_document_vectors(self, kb_id: str, filename: str):
        """删除某个文档的所有向量。"""
        name = self._collection_name(kb_id)
        self.client.delete(
            collection_name=name,
            points_selector=models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=filename),
                    )
                ]
            ),
        )


vector_store = VectorStore()

7. 文档处理器(整合所有步骤)

# backend/app/services/document_processor.py(完整)
import logging
from typing import Dict, Any

from app.services.storage import file_storage
from app.services.text_splitter import RecursiveTextSplitter
from app.services.embedding import embedding_service
from app.services.vector_store import vector_store

logger = logging.getLogger(__name__)


class DocumentProcessor:
    """文档处理器——整合解析、切分、Embedding、存储。"""

    async def process(
        self,
        document_id: str,
        kb_id: str,
        file_path: str,
        filename: str,
        file_type: str,
        chunk_size: int = 512,
        chunk_overlap: int = 128,
    ) -> Dict[str, Any]:
        """处理文档的完整 Pipeline。"""
        logger.info(f"Processing document: {filename} (ID: {document_id})")

        # Step 1: 解析文档
        parser = DocumentParser()
        text = parser.parse(file_path, filename, file_type)
        logger.info(f"Parsed: {len(text)} chars")

        if not text.strip():
            raise ValueError("文档内容为空")

        # Step 2: 文本切分
        splitter = RecursiveTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = splitter.split_with_metadata(text, filename)
        logger.info(f"Split into {len(chunks)} chunks")

        # Step 3: Embedding 向量化
        texts = [c["text"] for c in chunks]
        vectors = embedding_service.encode(texts)
        logger.info(f"Generated {len(vectors)} vectors (dim={embedding_service.dimension})")

        # Step 4: 存入 Qdrant
        await vector_store.ensure_collection(kb_id, embedding_service.dimension)
        await vector_store.upsert_chunks(kb_id, chunks, vectors)
        logger.info(f"Stored in Qdrant")

        return {
            "document_id": document_id,
            "chunk_count": len(chunks),
            "char_count": len(text),
        }

8. 上传即处理——触发 Pipeline

修改文档上传接口,上传成功后自动触发 Celery 任务:

# backend/app/routers/documents.py(更新)
from app.tasks.document_tasks import process_document


@router.post("/{kb_id}/documents", response_model=DocumentResponse, status_code=201)
async def upload_document(
    kb_id: str,
    file: UploadFile = File(...),
    user: User = Depends(require_auth),
    db: AsyncSession = Depends(get_db),
):
    """上传文档并触发处理。"""
    file_data = await file.read()
    try:
        doc = await DocumentService.upload(
            db, kb_id, str(user.id), file_data, file.filename
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # 异步触发文档处理
    process_document.delay(
        document_id=str(doc.id),
        kb_id=kb_id,
    )

    return _doc_to_response(doc)

9. 前端状态轮询

文档上传后状态为 pendingprocessingready,前端需要轮询直到完成:

// frontend/src/hooks/useDocumentPolling.ts
import { useState, useEffect, useCallback } from "react";
import { listDocuments, Document } from "@/api/knowledgeBase";

export function useDocumentPolling(kbId: string) {
  const [docs, setDocs] = useState<Document[]>([]);
  const [loading, setLoading] = useState(true);
  const [hasProcessing, setHasProcessing] = useState(false);

  const load = useCallback(async () => {
    const res = await listDocuments(kbId);
    setDocs(res.items);
    setLoading(false);

    const processing = res.items.some(
      (d) => d.status === "pending" || d.status === "processing"
    );
    setHasProcessing(processing);
  }, [kbId]);

  // 初始加载
  useEffect(() => {
    load();
  }, [load]);

  // 如果有处理中的文档,每 3 秒轮询一次
  useEffect(() => {
    if (!hasProcessing) return;
    const timer = setInterval(load, 3000);
    return () => clearInterval(timer);
  }, [hasProcessing, load]);

  return { docs, loading, refresh: load };
}

在前端 KnowledgeBaseDetail 中替换原来的简单加载:

// 替换 useEffect(() => { load(); }, [id]);
const { docs, loading, refresh } = useDocumentPolling(id || "");
// 并用 setDocs 的逻辑就不需要了,docs 直接从 hook 获取

10. 整个数据流

用户上传 PDF
    │
    ▼
FastAPI 接收文件
    ├─ 存入 MinIO(对象存储)
    └─ 创建 Document 记录(PostgreSQL,status=pending)
    │
    ▼
Celery 异步任务 process_document
    │
    ├─ Step 1: DocumentParser.parse()
    │   └─ 从 MinIO 下载 → PyMuPDF 解析 → 纯文本
    │
    ├─ Step 2: RecursiveTextSplitter.split()
    │   └─ 文本 → 按段落/句子递归切分 → chunks
    │
    ├─ Step 3: EmbeddingService.encode()
    │   └─ chunks → BGE 模型 → 向量数组
    │
    ├─ Step 4: VectorStore.upsert_chunks()
    │   └─ 向量 + 元数据 → Qdrant collection
    │
    └─ 更新 Document 状态(PostgreSQL,status=ready)
    │
    ▼
前端轮询到 status=ready,显示"已完成"
用户现在可以提问了

11. 验证

# 上传文档后查看处理状态
curl http://localhost:8000/api/knowledge-bases/<kb_id>/documents \
  -H "Authorization: Bearer <token>"

# 响应中的状态变化
# 刚上传:status="pending"
# 几秒后:status="processing"
# 处理完:status="ready", chunk_count=42

# 直接查询 Qdrant
curl http://localhost:6333/collections/kb_<uuid>/points/0

12. 性能与异常处理

大文件处理

# 在 DocumentProcessor 中增加超时和进度
import asyncio


async def process_with_timeout(self, *args, timeout: int = 300, **kwargs):
    """带超时的文档处理。"""
    try:
        result = await asyncio.wait_for(
            self.process(*args, **kwargs),
            timeout=timeout,
        )
        return result
    except asyncio.TimeoutError:
        raise TimeoutError("文档处理超时(>5分钟),文件可能过大")

异常状态处理

状态 含义 用户看到
pending 等待处理 ⏳ 等待处理
processing 正在处理 🔄 处理中
ready 处理完成 ✅ 已完成
failed 处理失败 ❌ 处理失败(可查看错误信息)

支持的文档格式

格式 解析引擎 限制
PDF PyMuPDF 扫描件无文本层(后续可用 OCR)
TXT 直接读取 无限制
Markdown 直接读取 无限制
DOCX python-docx 不支持带复杂表格的文档

总结

今天完成了整个文档处理 Pipeline:

组件 说明
文档解析 PDF/TXT/MD/DOCX → 纯文本
文本切分 递归字符切分(chunk_size=512, overlap=128)
Embedding BGE 模型向量化
Qdrant 存储 Collection 管理 + 向量 upsert
Celery 异步 上传即处理 + 重试机制
前端轮询 实时显示处理状态

现在上传的文档会自动处理成可检索的知识。下一篇我们将实现向量检索与 RAG 问答——用户提问时从 Qdrant 检索相关内容,交给 LLM 生成回答。


本文是 《AI 全栈开发实战——做一个真正的产品》 系列的第 5 篇。
系列目录:

  1. ✅ 产品定义与架构设计
  2. ✅ 技术选型与项目初始化
  3. ✅ 用户系统
  4. ✅ 知识库与文档管理
  5. ✅ 文档处理 Pipeline ← 你在这里
  6. 📝 向量检索与 RAG 问答

本文由 Zyentor(智元界) 原创发布


本文发布于 Zyentor(智元界) —— AI 开发者社区

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐