Embeddings Practical Exercises

Overview

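This article builds on the hands-on code in Text Embeddings.py and walks through the full technical pipeline: from embedding vectorization to semantic search, vector databases, and Reranker-based fine ranking.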


1. Embeddings and Vectorization

What Is an Embedding

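An embedding is a technique that converts unstructured data such as text, images, or audio into fixed-length vectors. Semantically similar content ends up closer together in the vector space.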

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embedding = model.encode("I love AI")
# shape: (384,) for a single sentence
# shape: (N, 384) when encoding a list of N sentences

Return Type

model.encode() returns a numpy.ndarray:

  • A single sentence → a 1-D vector of shape (384,)
  • A list of sentences → a 2-D matrix of shape (N, 384), one row per sentence
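To see that shape convention without downloading a model, the sketch below uses a hypothetical ToyEncoder (random vectors with no semantics, not a real embedding model) whose encode method mimics the return convention described above: a str gives a 1-D vector, a list gives an (N, dim) matrix whose row i belongs to sentence i.

```python
import numpy as np


# Hypothetical stand-in for SentenceTransformer.encode, used only to illustrate
# the return-shape convention; it is NOT a real embedding model.
class ToyEncoder:
    def __init__(self, dim=384, seed=0):
        self.dim = dim
        self.rng = np.random.default_rng(seed)
        self.cache = {}  # one fixed random vector per distinct text

    def _embed_one(self, text):
        if text not in self.cache:
            self.cache[text] = self.rng.standard_normal(self.dim).astype(np.float32)
        return self.cache[text]

    def encode(self, sentences):
        # A single string gives a 1-D vector; a list gives an (N, dim) matrix
        if isinstance(sentences, str):
            return self._embed_one(sentences)
        return np.stack([self._embed_one(s) for s in sentences])


enc = ToyEncoder()
single = enc.encode("I love AI")
batch = enc.encode(["I love AI", "The sky is blue", "Redis is a cache"])
print(single.shape)  # (384,)
print(batch.shape)   # (3, 384)
print(np.array_equal(batch[0], single))  # True: row 0 is the vector of the first sentence
```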

2. Similarity Computation

Cosine Similarity

Cosine similarity measures how closely two vectors point in the same direction, regardless of their length. Its range is [-1, 1]:

cosine_similarity(a, b) = dot(a, b) / (||a|| * ||b||)

import numpy as np

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

a = model.encode("I love AI")
b = model.encode("I like machine learning")
c = model.encode("The sky is blue")

print(cosine_similarity(a, b))  # higher (semantically related)
print(cosine_similarity(a, c))  # lower (semantically unrelated)
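The boundary cases of the [-1, 1] range can be checked with small hand-made 3-D vectors (made-up values, not model outputs). The sketch shows that only direction matters: scaling a vector leaves the similarity at 1, an orthogonal vector gives 0, and the opposite vector gives -1.

```python
import numpy as np


def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


a = np.array([1.0, 2.0, 3.0])
same_dir = 10 * a                         # same direction, 10x the length
orthogonal = np.array([3.0, 0.0, -1.0])   # dot(a, orthogonal) = 3 + 0 - 3 = 0
opposite = -a                             # exactly the opposite direction

print(cosine_similarity(a, same_dir))     # ≈ 1.0 (up to floating-point rounding)
print(cosine_similarity(a, orthogonal))   # 0.0
print(cosine_similarity(a, opposite))     # ≈ -1.0 (up to floating-point rounding)
```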

Batch Computation

No for loop is needed: a single matrix-vector product scores every document at once.

query_embedding = model.encode(query)  # (384,)
doc_embeddings = model.encode(documents)  # (N, 384)

# Dot products: the numerator of the cosine formula
scores = np.dot(doc_embeddings, query_embedding)

# Divide by the norms: the denominator of the cosine formula
query_norm = np.linalg.norm(query_embedding)
doc_norm = np.linalg.norm(doc_embeddings, axis=1)
scores = scores / (query_norm * doc_norm)
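A quick way to convince yourself the vectorized version is correct is to compare it against an explicit per-document loop. This sketch uses random vectors as stand-ins for model.encode output (an assumption, so it runs without a model); broadcasting divides each row's dot product by that row's norm times the query norm.

```python
import numpy as np

rng = np.random.default_rng(42)
query_embedding = rng.standard_normal(384)      # stand-in for model.encode(query)
doc_embeddings = rng.standard_normal((5, 384))  # stand-in for model.encode(documents)

# Vectorized: one matrix-vector product, then broadcasting over the row norms
scores = doc_embeddings @ query_embedding
scores = scores / (np.linalg.norm(query_embedding) * np.linalg.norm(doc_embeddings, axis=1))

# Reference: explicit Python loop, one document at a time
loop_scores = np.array([
    np.dot(d, query_embedding) / (np.linalg.norm(d) * np.linalg.norm(query_embedding))
    for d in doc_embeddings
])

print(scores.shape)                      # (5,)
print(np.allclose(scores, loop_scores))  # True
```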

3. Semantic Search

Core Flow

User query → Query Encoding → Vector Matching → Ranking → Top-K results

The Three Steps

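  1. Query Encoding: convert the user query into a vector
  2. Vector Matching: compute the similarity between the query vector and every document vector
  3. Semantic Ranking: sort by similarity and return the Top-K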

Basic Implementation

def semantic_search(query, documents, model, top_k=2):
    query_embedding = model.encode(query)
    doc_embeddings = model.encode(documents)

    # Batch cosine similarity
    scores = np.dot(doc_embeddings, query_embedding)
    query_norm = np.linalg.norm(query_embedding)
    doc_norm = np.linalg.norm(doc_embeddings, axis=1)
    scores = scores / (query_norm * doc_norm)

    # Sort by score and take the Top-K
    pairs = list(zip(documents, scores))
    pairs = sorted(pairs, key=lambda x: x[1], reverse=True)
    return pairs[:top_k]
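Sorting all N (document, score) pairs costs O(N log N). When N is large and k is small, NumPy's np.argpartition can select the k largest scores in linear time first and then sort only those k. This is an optional optimization sketched here, not part of the original code; top_k_indices is a hypothetical helper name.

```python
import numpy as np


def top_k_indices(scores, k):
    """Indices of the k largest scores, highest first."""
    k = min(k, len(scores))
    if k == 0:
        return np.array([], dtype=int)
    # argpartition moves the k largest scores into the last k slots (unordered)
    candidates = np.argpartition(scores, -k)[-k:]
    # sort only those k candidates, in descending order of score
    return candidates[np.argsort(scores[candidates])[::-1]]


scores = np.array([0.12, 0.87, 0.45, 0.91, 0.30])
print(top_k_indices(scores, 2))   # [3 1]
print(top_k_indices(scores, 10))  # [3 1 2 4 0]
```

The returned indices can then be zipped back to documents exactly as in semantic_search above.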

4. A Hand-Written Vector Database (VectorDB)

Why a Vector Database

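Re-encoding every document on each search is wasteful. A vector database computes the document embeddings once, stores them, and reuses them for every query.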

VectorDB Class Implementation

class VectorDB:
    def __init__(self, model):
        self.model = model
        self.documents = []
        self.embeddings = None

    def add(self, documents):
        new_embeddings = self.model.encode(documents)
        self.documents.extend(documents)

        # Initialize on the first add, otherwise append the new rows
        if self.embeddings is None:
            self.embeddings = new_embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, new_embeddings])

    def _cosine_similarity(self, query_vec, doc_matrix):
        dot = np.dot(doc_matrix, query_vec)
        query_norm = np.linalg.norm(query_vec)
        doc_norm = np.linalg.norm(doc_matrix, axis=1)
        return dot / (query_norm * doc_norm)

    def search(self, query, top_k=3):
        if self.embeddings is None:
            return []  # nothing has been added yet
        query_embedding = self.model.encode(query)
        scores = self._cosine_similarity(query_embedding, self.embeddings)

        top_indices = np.argsort(scores)[::-1][:top_k]
        return [{"document": self.documents[idx], "score": float(scores[idx])}
                for idx in top_indices]
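A common variant of this class normalizes vectors once at insert time, so every search reduces to a single dot product instead of recomputing document norms per query. The sketch below assumes that design; NormalizedVectorDB and the BagOfWordsEncoder (word counts over a fixed vocabulary, a toy stand-in for a real model) are hypothetical names introduced only for illustration.

```python
import numpy as np


class BagOfWordsEncoder:
    """Hypothetical toy encoder: word counts over a fixed vocabulary (not a real model)."""

    def __init__(self, vocab):
        self.index = {w: i for i, w in enumerate(vocab)}

    def _one(self, text):
        v = np.zeros(len(self.index), dtype=np.float32)
        for w in text.lower().split():
            if w in self.index:
                v[self.index[w]] += 1.0
        return v

    def encode(self, texts):
        if isinstance(texts, str):
            return self._one(texts)
        return np.stack([self._one(t) for t in texts])


class NormalizedVectorDB:
    """VectorDB variant that stores unit-length rows, so search is one dot product."""

    def __init__(self, model):
        self.model = model
        self.documents = []
        self.embeddings = None

    @staticmethod
    def _normalize(x):
        norms = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.maximum(norms, 1e-12)  # guard against all-zero vectors

    def add(self, documents):
        new = self._normalize(self.model.encode(documents))
        self.documents.extend(documents)
        self.embeddings = new if self.embeddings is None else np.vstack([self.embeddings, new])

    def search(self, query, top_k=3):
        if self.embeddings is None:
            return []  # empty database
        q = self._normalize(self.model.encode(query))
        scores = self.embeddings @ q  # cosine similarity, since both sides have unit length
        top = np.argsort(scores)[::-1][:top_k]
        return [{"document": self.documents[i], "score": float(scores[i])} for i in top]


enc = BagOfWordsEncoder(["backend", "api", "database", "cache", "sky", "blue"])
db = NormalizedVectorDB(enc)
db.add(["backend api database", "redis cache database"])
db.add(["the sky is blue"])
print(db.search("backend api", top_k=1))  # top hit: "backend api database"
```

The trade-off is that the original, unnormalized vectors are no longer kept; if you need them later, store both.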

5. FAISS Vector Database

What Is FAISS

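Facebook AI Similarity Search (FAISS) is a vector search library developed by Facebook (now Meta). It performs efficient similarity search over very large collections of vectors.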

FAISS Index Types

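Type | Description | Characteristics
IndexFlatL2 | Brute-force (exact) search by L2 distance | Highest accuracy, slowest
IndexFlatIP | Brute-force inner-product search (equivalent to cosine similarity after normalization) | Highest accuracy, slowest
IndexIVFFlat | Inverted file index (vectors bucketed by cluster) | Faster, approximate; requires training
IndexHNSWFlat | Hierarchical Navigable Small World graph | Fast approximate search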
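What the two Flat indexes compute can be written out directly in NumPy. The sketch below is an assumption-level emulation (flat_search is a hypothetical helper, not a FAISS function): it scans every database vector and returns (distances, indices) arrays of shape (n_queries, k), ranking by largest inner product for "ip" and by smallest squared Euclidean distance for "l2" (FAISS's L2 indexes report squared distances).

```python
import numpy as np


def flat_search(xb, xq, k, metric="ip"):
    """Exhaustive search over database rows xb for each query row xq.

    Returns (distances, indices), both of shape (n_queries, k).
    """
    if metric == "ip":
        scores = xq @ xb.T  # (nq, nb), larger is more similar
        order = np.argsort(-scores, axis=1)[:, :k]
    elif metric == "l2":
        # squared L2 = ||q||^2 + ||b||^2 - 2 q.b, smaller is more similar
        scores = (xq ** 2).sum(1)[:, None] + (xb ** 2).sum(1)[None, :] - 2 * xq @ xb.T
        order = np.argsort(scores, axis=1)[:, :k]
    else:
        raise ValueError(f"unknown metric: {metric}")
    return np.take_along_axis(scores, order, axis=1), order


xb = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]])
xq = np.array([[1.0, 0.1]])
D, I = flat_search(xb, xq, k=2, metric="ip")
print(I)   # [[0 2]]
D2, I2 = flat_search(xb, xq, k=2, metric="l2")
print(I2)  # [[0 2]]
```

This full scan is why Flat indexes are exact but slow; IVF and HNSW trade some accuracy for skipping most of it.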

Vector Normalization

faiss.normalize_L2() scales each vector to unit length in place, so that the inner product equals cosine similarity:

import faiss

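doc_embeddings = model.encode(documents).astype("float32")  # FAISS expects float32
faiss.normalize_L2(doc_embeddings)  # in place: scale each row to unit length

dim = doc_embeddings.shape[1]   # embedding dimension, 384 for all-MiniLM-L6-v2
index = faiss.IndexFlatIP(dim)  # exact inner-product index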
index.add(doc_embeddings)
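Normalization also explains why L2 and inner-product search agree on unit vectors: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b = 2 - 2 cos(a, b). The sketch checks this numerically with random float32 vectors (stand-ins for embeddings) and plain NumPy normalization in place of faiss.normalize_L2, so it runs without FAISS installed.

```python
import numpy as np

rng = np.random.default_rng(0)
docs = rng.standard_normal((6, 8)).astype("float32")
query = rng.standard_normal(8).astype("float32")

# Same effect as faiss.normalize_L2, but returning new arrays instead of working in place
docs_n = docs / np.linalg.norm(docs, axis=1, keepdims=True)
query_n = query / np.linalg.norm(query)

ip = docs_n @ query_n                          # cosine similarities
sq_l2 = ((docs_n - query_n) ** 2).sum(axis=1)  # squared Euclidean distances

# For unit vectors: ||a - b||^2 = 2 - 2 * (a . b)
print(np.allclose(sq_l2, 2 - 2 * ip, atol=1e-5))  # True
# Largest inner product = smallest distance, so both give the same ranking
print(np.argsort(-ip), np.argsort(sq_l2))
```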

Search

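k = 3  # number of nearest neighbors to return
query_embedding = model.encode([query]).astype("float32")  # [query] keeps the input 2-D: (1, 384)
faiss.normalize_L2(query_embedding)

distances, indices = index.search(query_embedding, k)
# distances: inner products, i.e. cosine similarities here (larger is more similar), shape (1, k)
# indices: positions of the matching documents, shape (1, k)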
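When the index holds fewer than k vectors, FAISS pads the indices row with -1. In Python, documents[-1] silently returns the last document, so it is worth skipping those slots when mapping results back to text. The sketch assumes hypothetical search output (made-up numbers) for k=3 over a 2-document index; map_results is an illustrative helper name.

```python
import numpy as np


def map_results(distances, indices, documents):
    """Map one query's (distances, indices) rows back to documents, skipping -1 padding."""
    results = []
    for score, idx in zip(distances[0], indices[0]):
        if idx == -1:
            continue  # no result in this slot
        results.append({"document": documents[idx], "similarity": float(score)})
    return results


documents = ["Redis is an in-memory cache", "REST APIs connect services"]
# Hypothetical output of index.search(query_embedding, 3) on a 2-vector index
distances = np.array([[0.82, 0.41, 0.0]], dtype="float32")  # last value is a placeholder
indices = np.array([[1, 0, -1]])
print(map_results(distances, indices, documents))
```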

6. FaissVectorDB Wrapper Class

Full Implementation

import os
import pickle

import faiss


class FaissVectorDB:
    def __init__(self, model):
        self.model = model
        self.documents = []
        self.index = None

    def add(self, documents):
        self.documents.extend(documents)
        doc_embeddings = self.model.encode(documents).astype("float32")
        faiss.normalize_L2(doc_embeddings)

        if self.index is None:
            self.index = faiss.IndexFlatIP(doc_embeddings.shape[1])
        self.index.add(doc_embeddings)

    def search(self, query, top_k=3):
        query_embedding = self.model.encode([query]).astype("float32")
        faiss.normalize_L2(query_embedding)

        distances, indices = self.index.search(query_embedding, top_k)
        # FAISS pads indices with -1 when the index holds fewer than top_k vectors
        return [{"document": self.documents[idx], "similarity": float(distances[0][i])}
                for i, idx in enumerate(indices[0]) if idx != -1]

    def save(self, path="./"):
        db_path = os.path.join(path, "db")
        os.makedirs(db_path, exist_ok=True)

        faiss.write_index(self.index, os.path.join(db_path, "index.faiss"))
        with open(os.path.join(db_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.documents, f)

    def load(self, path="./"):
        db_path = os.path.join(path, "db")
        self.index = faiss.read_index(os.path.join(db_path, "index.faiss"))
        with open(os.path.join(db_path, "docs.pkl"), "rb") as f:
            self.documents = pickle.load(f)

Persistence

  • index.faiss: the FAISS index file
  • docs.pkl: the raw document list (FAISS stores only vectors, not the original text)
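The hand-written VectorDB from section 4 has no FAISS index, but the same two-file idea applies: store the vectors and the raw texts side by side, with row i belonging to documents[i]. The sketch below swaps in np.save plus JSON (a format choice of this example, not the original code; JSON also avoids unpickling untrusted files). save_db and load_db are hypothetical helper names.

```python
import json
import os
import tempfile

import numpy as np


def save_db(path, documents, embeddings):
    """Save vectors and raw texts side by side; row i of embeddings belongs to documents[i]."""
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, "embeddings.npy"), embeddings)
    with open(os.path.join(path, "docs.json"), "w", encoding="utf-8") as f:
        json.dump(documents, f, ensure_ascii=False)


def load_db(path):
    embeddings = np.load(os.path.join(path, "embeddings.npy"))
    with open(os.path.join(path, "docs.json"), encoding="utf-8") as f:
        documents = json.load(f)
    if len(documents) != embeddings.shape[0]:
        raise ValueError("documents and embeddings are out of sync")
    return documents, embeddings


with tempfile.TemporaryDirectory() as tmp:
    docs = ["Backend systems include servers", "CSS styles web pages"]
    embs = np.random.default_rng(1).standard_normal((2, 4)).astype("float32")
    save_db(os.path.join(tmp, "db"), docs, embs)
    docs2, embs2 = load_db(os.path.join(tmp, "db"))
    print(docs2 == docs, np.array_equal(embs2, embs))  # True True
```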

7. Bi-Encoder vs. Cross-Encoder

Bi-Encoder

Sentence A → BERT Encoder → Vector A
Sentence B → BERT Encoder → Vector B
         ↓
    Cosine Similarity
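  • Encodes the query and the document independently
  • Document embeddings can be precomputed, so search is fast
  • Moderate accuracy: the query and document never interact inside the model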

Cross-Encoder

[CLS] Sentence A [SEP] Sentence B [SEP] → BERT Encoder → Score
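  • Concatenates the pair and encodes it jointly, so the query and document interact fully
  • High accuracy, but nothing can be precomputed: every query requires encoding each (query, document) pair
  • Best suited to reranking a small candidate set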

Comparison

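Feature | Bi-Encoder (SBERT) | Cross-Encoder
Encoding | Query and document encoded independently | Concatenated and encoded jointly
Speed | Fast (document embeddings can be precomputed) | Slow (computed per query-document pair at query time)
Accuracy | Moderate | High
Use case | Semantic search over large candidate sets | Reranking a small candidate set
Storage | Must store the vectors | No vectors to store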
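The speed row comes down to simple arithmetic on model forward passes. The sketch counts one pass per encoded text (bi-encoder) or per text pair (cross-encoder), for hypothetical workload sizes; it ignores that passes differ in length and cost, so treat it as an order-of-magnitude illustration rather than a latency model. forward_passes is an illustrative helper name.

```python
def forward_passes(num_queries, num_docs, rerank_k=None):
    """Count model forward passes (one per encoded text or text pair)."""
    bi_encoder = num_docs + num_queries     # docs encoded once, plus one pass per query
    cross_encoder = num_queries * num_docs  # every (query, doc) pair, nothing precomputed
    result = {"bi_encoder": bi_encoder, "cross_encoder": cross_encoder}
    if rerank_k is not None:
        # retrieve with the bi-encoder, then cross-encode only the top rerank_k per query
        result["retrieve_then_rerank"] = bi_encoder + num_queries * rerank_k
    return result


print(forward_passes(num_queries=100, num_docs=1_000_000, rerank_k=12))
# {'bi_encoder': 1000100, 'cross_encoder': 100000000, 'retrieve_then_rerank': 1001300}
```

The last entry shows why the two are combined: reranking only a dozen candidates per query adds little to the bi-encoder's cost.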

8. Reranking

Why a Reranker

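Vector retrieval finds candidates whose embeddings lie close to the query, but because the query and each document are encoded independently, it cannot judge fine-grained relevance precisely. A reranker re-scores the coarse-ranked candidates (fine ranking) to improve the quality of the final results.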

Pipeline

Documents  → Embedding → Vector Database
                              ↓
User query → Embedding → ANN search → Top-K → Reranker → Results
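The two-stage shape of this pipeline is independent of the specific models. The sketch below passes the stages in as functions: retrieve stands in for the vector DB search and score_pair for the cross-encoder. The toy word-overlap versions and the names retrieve_then_rerank, toy_retrieve, and toy_score are hypothetical, chosen only so the example runs without models.

```python
def retrieve_then_rerank(query, retrieve, score_pair, recall_k=12, final_k=5):
    """Two-stage search: cheap recall of recall_k candidates, then a precise rerank."""
    candidates = retrieve(query, recall_k)
    scored = [(doc, score_pair(query, doc)) for doc in candidates]
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:final_k]


# Hypothetical stand-ins so the sketch runs without models
corpus = ["backend APIs", "backend databases and APIs", "sunny sky", "football in Europe"]


def toy_retrieve(query, k):
    # recall: any document sharing a word with the query, in corpus order
    q = set(query.lower().split())
    return [d for d in corpus if q.intersection(d.lower().split())][:k]


def toy_score(query, doc):
    # "precise" score: fraction of query words present in the document
    q = query.lower().split()
    d = set(doc.lower().split())
    return sum(w in d for w in q) / len(q)


result = retrieve_then_rerank("backend databases", toy_retrieve, toy_score, recall_k=3, final_k=2)
print(result)  # [('backend databases and APIs', 1.0), ('backend APIs', 0.5)]
```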

Implementation

from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')

# Coarse ranking: recall the Top-K candidates
res = db.search(query, 12)

# Build (query, document) pairs
pairs = [(query, r["document"]) for r in res]

# Score each pair with the Cross-Encoder
scores = reranker.predict(pairs)

reranked = sorted(
    [{"document": r["document"], "score": scores[i]} for i, r in enumerate(res)],
    key=lambda x: x["score"],
    reverse=True
)

final = reranked[:5]

9. End-to-End Semantic Search Pipeline

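from sentence_transformers import CrossEncoder, SentenceTransformer

# Assumes the FaissVectorDB class (section 6) and a `documents` list are already defined
# 1. Initialize the models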
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')

# 2. Build and save the vector database
db = FaissVectorDB(model)
db.add(documents)
db.save()

# 3. Load and search
db.load()
query = "How can I build a backend system?"
res = db.search(query, 12)

# 4. Rerank with the Cross-Encoder
pairs = [(query, r["document"]) for r in res]
scores = reranker.predict(pairs)

reranked = sorted(
    [{"document": r["document"], "score": scores[i]} for i, r in enumerate(res)],
    key=lambda x: x["score"],
    reverse=True
)

final = reranked[:5]

10. Key Parameters and Functions

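Parameter / function | Description
model.encode() | Converts text to a numpy ndarray
faiss.normalize_L2() | Normalizes vectors to unit length, in place
IndexFlatIP | Inner-product index (equivalent to cosine similarity after normalization)
IndexFlatL2 | L2 (Euclidean) distance index; reports squared distances
index.search() | Returns distances and indices, each of shape (n_queries, k)
CrossEncoder.predict() | Scores (query, doc) pairs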

How the Pieces Connect

Embedding vectorization
      ↓
Semantic search (Bi-Encoder)
      ↓
Vector database (FAISS)
      ↓
Reranking (Cross-Encoder)
      ↓
RAG / recommender systems / anomaly detection

Resources

Full exercise source code (for reference only)

import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Cross-encoder that reranks the top-k documents returned by search (fine ranking)
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')

# documents = [                                                                                                
#     "Apple released a new iPhone with improved camera system",
#     "Microsoft announced new features for Windows 12",                                                       
#     "The stock market rallied today with tech stocks leading",
#     "Scientists discovered a new species of dinosaur in Argentina",                                          
#     "A new AI model achieves state-of-the-art results on benchmarks",
#     "Tesla unveiled their latest electric vehicle with longer range",                                        
#     "The government announced new policies for renewable energy",                                            
#     "Manchester United won the football match 3-1",                                                          
# ] 

# document_embeddings = model.encode(documents)

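# print(f"Number of documents: {len(documents)}")
# print(f"Embedding dimension: {document_embeddings.shape[1]}")
# print(f"Embedding matrix shape: {document_embeddings.shape}")

# Return value of encode()
# print(type(document_embeddings))
# <class 'numpy.ndarray'>
# For a single sentence the return value is a 1-D float vector
# For several sentences it is a 2-D array, one row per sentence -> (8, 384)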


# Cosine similarity
# def cosine_similarity(a, b):
#     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# a = model.encode("I love AI")
# b = model.encode("I like machine learning")
# c = model.encode("The sky is blue")

# print(cosine_similarity(a, b))  # should be higher
# print(cosine_similarity(a, c))  # should be lower

# Test documents
# documents = [
#     # AI / ML
#     "Machine learning is a subset of artificial intelligence",
#     "Deep learning uses neural networks with many layers",
#     "Natural language processing helps computers understand text",

#     # Backend
#     "Backend development involves databases and APIs",
#     "RESTful APIs are commonly used in web services",
#     "Redis is an in-memory database used for caching",
#     "Backend systems include servers, databases, and APIs",
#     "Building scalable backend systems requires system design knowledge",
#     "Backend engineers design REST APIs and handle data processing",

#     # Frontend
#     "Frontend development focuses on user interfaces",
#     "React is a popular JavaScript library for building UI",
#     "CSS is used to style web pages",

#     # Random / noise
#     "The weather today is sunny with a clear sky",
#     "Football is a popular sport in Europe",
#     "Cooking pasta requires boiling water"
# ]

documents = [
    "Backend systems include servers, databases, and APIs",
    "Building scalable backend systems requires system design knowledge",
    "Backend development involves databases and APIs",
    "REST APIs are commonly used for backend communication",
    "Microservices architecture helps scale backend systems",
    "Databases can be SQL or NoSQL depending on the use case",
    "Redis is often used as an in-memory cache in backend systems",
    "Load balancing distributes traffic across multiple servers",
    "A message queue like Kafka helps decouple backend services",
    "Authentication and authorization are critical backend components",
    "System design interviews often involve designing scalable services",
    "Horizontal scaling means adding more machines to handle load",
    "Vertical scaling means increasing resources of a single machine",
    "Backend performance can be improved using caching strategies",
    "Indexing in databases improves query performance",
    "SQL joins are used to combine data from multiple tables",
    "NoSQL databases are suitable for flexible schema requirements",
    "API gateways manage requests between clients and services",
    "Docker is used to containerize backend applications",
    "Kubernetes manages container orchestration at scale",
    "Latency is the time delay between request and response",
    "Throughput measures how many requests a system can handle",
    "CAP theorem explains tradeoffs in distributed systems",
    "Consistency ensures all nodes see the same data",
    "Availability ensures system responds even under failures",
    "Partition tolerance allows system to operate despite network issues",
    "Backend logging is important for debugging and monitoring",
    "Prometheus and Grafana are used for system monitoring",
    "CI/CD pipelines automate backend deployment",
    "Version control systems like Git manage code changes",
    "Unit testing ensures individual backend components work correctly",
    "Integration testing checks interaction between services",
    "GraphQL provides an alternative to REST APIs",
    "WebSockets enable real-time backend communication",
    "JWT is commonly used for stateless authentication",
    "OAuth is an authorization framework for third-party access",
    "Backend security includes encryption and secure communication",
    "TLS encrypts data in transit between client and server",
    "Database normalization reduces data redundancy",
    "Denormalization can improve read performance in some systems"
]


# doc_embeddings = model.encode(documents)
# print(doc_embeddings.shape)
# print(doc_embeddings.dtype)


# Semantic search
# 1. Encode the query
# 2. Encode the documents
# 3. Compute similarity
# 4. Sort
# 5. Return the top_k
# def semantic_search(query, documents, model, top_k=2):
#     query_embedding = model.encode(query)
#     doc_embeddings = model.encode(documents)
#     # Vectorized: no Python for loop over the N documents
#     scores = np.dot(doc_embeddings, query_embedding)
#     # Compute the norms
#     query_norm = np.linalg.norm(query_embedding)
#     doc_norm = np.linalg.norm(doc_embeddings, axis = 1) # norm of each row (each document)

#     scores = scores / (query_norm * doc_norm)
#     pairs = list(zip(documents, scores))
#     pairs = sorted(pairs, key = lambda x: x[1], reverse=True)
#     return pairs[:top_k]


# res = semantic_search(
#     "How can I build a backend system?",
#     documents,
#     model
# )

# print(res)


# A hand-written vector DB, so document embeddings are not recomputed on every search
class VectorDB:
    def __init__(self, model):
        self.model = model
        self.documents = []
        self.embeddings = None

    def add(self, documents):
        # Encode every document added in this call
        new_embeddings = self.model.encode(documents)
        # Keep the raw documents
        self.documents.extend(documents)
        # Initialize or append the embeddings
        if self.embeddings is None:
            self.embeddings = new_embeddings
        else:
            self.embeddings = np.vstack([
                self.embeddings,
                new_embeddings
            ])

    def _cosine_similarity(self, query_vec, doc_matrix):
        # Dot products
        dot = np.dot(doc_matrix, query_vec)

        # Norms
        query_norm = np.linalg.norm(query_vec)
        doc_norm = np.linalg.norm(doc_matrix, axis=1)

        return dot / (query_norm * doc_norm)

    def search(self, query, top_k = 3):
        # encode query
        query_embedding = self.model.encode(query)
        # calculate cosine similarity
        scores = self._cosine_similarity(query_embedding, self.embeddings)

        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        for idx in top_indices:
            results.append({
                "document": self.documents[idx],
                "score": float(scores[idx])
            })

        return results


        
# db = VectorDB(model)
# db.add(documents)
# res = db.search("How can I build a backend system?")
# print(res)


# Build a vector database with FAISS
# doc_embeddings = model.encode(documents)
# dim = doc_embeddings.shape[1]

# Flat = brute-force (exact) search
# L2   = Euclidean distance (FAISS reports squared L2 distances)
# index = faiss.IndexFlatL2(dim)
# index.add(doc_embeddings)

# query = "How to build backend system?"
# Pass [query] so the input is a 2-D array
# query_embedding = model.encode([query])

# k = 3 # top_k
# distances, indices = index.search(query_embedding, k)
# distances are squared L2 distances (smaller is more similar); indices are the positions of the most similar documents
# print(distances)
# print(indices)

# FAISS stores only vectors, so map the returned indices back to the raw documents
# res = []
# for i, idx in enumerate(indices[0]):
#     res.append({
#         "document": documents[idx],
#         "distance": distances[0][i]
#     })
# print(res)


# A FAISS-backed vector DB wrapper (ranked by cosine similarity)
class FaissVectorDB:
    def __init__(self, model):
        self.model = model
        self.documents = []
        self.index = None

    def add(self, documents):
        self.documents.extend(documents)
        doc_embeddings = self.model.encode(documents).astype("float32")
        # Normalize to unit length so that inner product equals cosine similarity
        faiss.normalize_L2(doc_embeddings)
        # Create the index from the embedding dimension
        if self.index is None:
            self.index = faiss.IndexFlatIP(doc_embeddings.shape[1])
        self.index.add(doc_embeddings)


    def search(self, query, top_k = 3):
        query_embedding = self.model.encode([query]).astype("float32")
        faiss.normalize_L2(query_embedding)

        distances, indices = self.index.search(query_embedding, top_k)
        res = []
        for i, idx in enumerate(indices[0]):
            res.append({
                "documents": self.documents[idx],
                "similarity": distances[0][i]
            })
        return res
    

    def save(self, path = "./"):
        db_path = os.path.join(path, "db")
        os.makedirs(db_path, exist_ok=True)

        if self.index is None:
            raise ValueError("No index to save")

        faiss.write_index(self.index, os.path.join(db_path, "index.faiss"))
        print("index.faiss save success")

        with open(os.path.join(db_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.documents, f)
        print("documents save success")



    def load(self, path = "./"):
        db_path = os.path.join(path, "db")

        index_path = os.path.join(db_path, "index.faiss")
        docs_path = os.path.join(db_path, "docs.pkl")

        if not os.path.exists(index_path):
            raise FileNotFoundError("FAISS index file not found")
        if not os.path.exists(docs_path):
            raise FileNotFoundError("Documents file not found")       

        self.index = faiss.read_index(index_path) 
        print("index.faiss load success")
        with open(docs_path, "rb") as f:
            self.documents = pickle.load(f)
        print("documents load success")



db = FaissVectorDB(model)
# First run: build and save the database
# db.add(documents)
# db.save()

# Load the database
query = "How can I build a backend system?"
db.load()
res = db.search(query, 12)

print("============ Top-k ============")
for r in res:
    print(r)


# Rerank
# Build (query, doc) pairs
pairs = [(query, r["documents"]) for r in res]
# Score how relevant each document is to the query
scores = reranker.predict(pairs)

reranked = []
for i, r in enumerate(res):
    reranked.append({
        "document": r["documents"],
        "score": scores[i]
    })

reranked = sorted(reranked, key = lambda x: x["score"], reverse=True)

final = reranked[:5]
print("============ Rerank ============")
for f in final:
    print(f)