第05章 数据存储与检索

5.1 PostgreSQL

5.1.1 数据模型设计

表设计最佳实践:

-- Users table.
-- NOTE: DEFAULT CURRENT_TIMESTAMP is applied only on INSERT; keeping
-- updated_at current on UPDATE needs a trigger or application code.
-- TIMESTAMP means "timestamp without time zone"; consider TIMESTAMPTZ when
-- clients run in different time zones.
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    username VARCHAR(50) UNIQUE NOT NULL,      -- UNIQUE also creates an index
    email VARCHAR(100) UNIQUE NOT NULL,
    password_hash VARCHAR(255) NOT NULL,       -- store a hash, never the plain password
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Articles table; author_id references users(id).
-- With no ON DELETE clause the default is NO ACTION, so deleting a user who
-- still has articles fails with a foreign-key violation.
-- PostgreSQL does not index foreign-key columns automatically.
CREATE TABLE articles (
    id SERIAL PRIMARY KEY,
    title VARCHAR(255) NOT NULL,
    content TEXT,
    author_id INTEGER REFERENCES users(id),
    published BOOLEAN DEFAULT FALSE,
    published_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

索引设计:

-- Single-column index: speeds up lookups and joins on author_id.
-- NOTE(review): the composite index below also has author_id as its leading
-- column and can serve author_id-only queries, so this index is likely
-- redundant (it still costs write overhead).
CREATE INDEX idx_articles_author_id ON articles(author_id);

-- Composite index for filters on author_id AND published. Column order
-- matters: the leading column author_id can be used on its own, while a
-- filter on published alone generally cannot use this index efficiently.
CREATE INDEX idx_articles_author_published ON articles(author_id, published);

-- Partial index: only rows with published = TRUE are indexed, keeping it small.
-- NOTE(review): keying it on id duplicates the primary-key index; a column that
-- published-article queries sort or filter on (e.g. published_at) is
-- probably a more useful key.
CREATE INDEX idx_articles_published ON articles(id) WHERE published = TRUE;

5.1.2 JSON类型

JSONB字段操作:

-- Table with a JSONB column (binary JSON: indexable, supports containment operators)
CREATE TABLE products (
    id SERIAL PRIMARY KEY,
    name VARCHAR(100),
    attributes JSONB
);

-- Insert a JSON document; the string literal is parsed into JSONB
INSERT INTO products (name, attributes) VALUES (
    'Laptop',
    '{"brand": "Apple", "model": "MacBook Pro", "specs": {"cpu": "M2", "ram": 16}}'
);

-- Read a top-level key: ->> returns text, while -> would return jsonb.
-- Nested keys: attributes->'specs'->>'cpu' or attributes #>> '{specs,cpu}'.
SELECT attributes->>'brand' AS brand FROM products;

-- Containment query: @> matches rows whose attributes contain the given JSON.
-- A GIN index (CREATE INDEX ... ON products USING gin (attributes)) can serve it.
SELECT * FROM products WHERE attributes @> '{"brand": "Apple"}';

Python操作JSONB:

import psycopg2
import json

# Connect to the database. The try/finally closes the connection even if the
# INSERT fails.
conn = psycopg2.connect("dbname=example user=postgres")
try:
    # The "attributes" dict is stored in the JSONB column
    product = {
        "name": "Laptop",
        "attributes": {
            "brand": "Apple",
            "model": "MacBook Pro",
            "specs": {"cpu": "M2", "ram": 16}
        }
    }

    # `with conn` commits on success and rolls back on error;
    # `with conn.cursor()` closes the cursor afterwards.
    with conn, conn.cursor() as cur:
        # Serialize the dict to a JSON string, which PostgreSQL casts to JSONB
        # on insert (psycopg2.extras.Json is an alternative adapter).
        cur.execute("""
            INSERT INTO products (name, attributes)
            VALUES (%s, %s)
        """, (product["name"], json.dumps(product["attributes"])))
finally:
    conn.close()

5.1.3 全文搜索

使用tsvector和tsquery:

-- Add a tsvector column with the pre-processed text, fill it, and index it
-- with GIN for fast @@ matching.
-- coalesce() is required: if title or content is NULL, the whole
-- concatenation is NULL and that row could never match a search.
ALTER TABLE articles ADD COLUMN search_vector tsvector;
UPDATE articles
SET search_vector = to_tsvector('english', coalesce(title, '') || ' ' || coalesce(content, ''));
CREATE INDEX idx_articles_search ON articles USING gin(search_vector);
-- NOTE: the UPDATE only fills existing rows. To keep the column current for
-- new or edited rows, use a trigger or (PostgreSQL 12+) declare the column as
-- GENERATED ALWAYS AS (to_tsvector(...)) STORED.

-- Full-text query: to_tsquery needs explicit operators between terms (& | ! <->)
SELECT title FROM articles
WHERE search_vector @@ to_tsquery('english', 'data & storage');

-- Ranked search. plainto_tsquery takes plain words and ANDs them together,
-- whereas to_tsquery('english', 'data storage') raises a tsquery syntax error
-- because the two terms have no operator between them.
SELECT title, ts_rank(search_vector, q) AS rank
FROM articles, plainto_tsquery('english', 'data storage') AS q
WHERE search_vector @@ q
ORDER BY rank DESC;

案例1:PostgreSQL CRUD操作

import psycopg2
from psycopg2 import sql
from datetime import datetime

class PostgresCRUD:
    """Small CRUD helper for the ``users``/``articles`` tables over one psycopg2 connection.

    Write methods commit on success and roll back on failure. Without the
    rollback, a failed statement leaves the connection in an aborted
    transaction and every later command fails until a rollback happens.
    Instances can also be used as a context manager that closes the
    connection on exit.
    """

    def __init__(self, dbname, user, password, host="localhost", port=5432):
        self.conn = psycopg2.connect(
            dbname=dbname, user=user, password=password, host=host, port=port
        )
        self.cur = self.conn.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def _execute_write(self, query, params, fetch_one=False):
        """Run a data-modifying statement and commit it; roll back and re-raise on error.

        Args:
            query: SQL string or ``psycopg2.sql.Composable``.
            params: Sequence of bind parameters.
            fetch_one: If true, fetch the first result row (e.g. from
                ``RETURNING``) before committing and return it.

        Returns:
            The fetched row, or ``None`` when ``fetch_one`` is false.
        """
        try:
            self.cur.execute(query, params)
            row = self.cur.fetchone() if fetch_one else None
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise
        return row

    def create_user(self, username, email, password_hash):
        """Insert a user and return its generated ``id``."""
        row = self._execute_write("""
            INSERT INTO users (username, email, password_hash)
            VALUES (%s, %s, %s)
            RETURNING id
        """, (username, email, password_hash), fetch_one=True)
        return row[0]

    def get_user(self, user_id):
        """Return the ``users`` row for ``user_id`` as a tuple, or ``None`` if it does not exist."""
        self.cur.execute("SELECT * FROM users WHERE id = %s", (user_id,))
        return self.cur.fetchone()

    def update_user(self, user_id, **kwargs):
        """Set the given columns of one user.

        Column names come from the keyword names and are quoted with
        ``sql.Identifier``; values are sent as bind parameters. With no keyword
        arguments this is a no-op, since it would otherwise generate the
        invalid statement ``UPDATE users SET  WHERE id = ...``.
        """
        if not kwargs:
            return
        set_clause = sql.SQL(", ").join(
            sql.SQL("{} = %s").format(sql.Identifier(column)) for column in kwargs
        )
        query = sql.SQL("UPDATE users SET {} WHERE id = %s").format(set_clause)
        self._execute_write(query, [*kwargs.values(), user_id])

    def delete_user(self, user_id):
        """Delete the user with ``user_id`` (no error if no such row exists)."""
        self._execute_write("DELETE FROM users WHERE id = %s", (user_id,))

    def search_articles(self, query, limit=10):
        """Full-text search over ``articles.search_vector``, best matches first.

        ``query`` is free text. ``plainto_tsquery`` ANDs its words together
        and ignores punctuation, whereas ``to_tsquery`` raises a syntax error
        on ordinary input such as ``"data storage"``.

        Returns:
            List of ``(title, rank)`` tuples, at most ``limit`` long.
        """
        self.cur.execute("""
            SELECT title, ts_rank(search_vector, q) AS rank
            FROM articles, plainto_tsquery('english', %s) AS q
            WHERE search_vector @@ q
            ORDER BY rank DESC
            LIMIT %s
        """, (query, limit))
        return self.cur.fetchall()

    def close(self):
        """Close the cursor and the connection."""
        self.cur.close()
        self.conn.close()

# Usage example: try/finally closes the connection even if the insert fails.
crud = PostgresCRUD(dbname="example", user="postgres", password="secret")
try:
    user_id = crud.create_user("john", "john@example.com", "hashed_password")
    print(f"Created user with ID: {user_id}")
finally:
    crud.close()

5.2 Redis

5.2.1 数据结构

Redis支持多种数据结构:

import redis

# Connect to a local Redis server, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)

# String: SET/GET. Replies are bytes unless the client is created with
# decode_responses=True.
r.set('name', 'John')
print(r.get('name'))  # b'John'

# Hash: a field -> value map stored under a single key
r.hset('user:1000', mapping={
    'username': 'john',
    'email': 'john@example.com',
    'age': '30'
})
print(r.hgetall('user:1000'))  # dict with bytes keys and values

# List: LPUSH inserts each value at the head, so the list now reads
# task3, task2, task1 (re-running the script keeps prepending).
r.lpush('tasks', 'task1', 'task2', 'task3')
print(r.lrange('tasks', 0, -1))  # all elements, head to tail

# Set: unordered collection of unique members
r.sadd('tags', 'python', 'redis', 'fastapi')
print(r.smembers('tags'))  # all members, in no guaranteed order

# Sorted Set: members ordered by score
r.zadd('scores', {'Alice': 95, 'Bob': 88, 'Charlie': 92})
print(r.zrange('scores', 0, -1, withscores=True))  # ascending score: Bob, Charlie, Alice

# HyperLogLog: approximate cardinality (distinct-count) estimation
r.pfadd('unique_visitors', 'user1', 'user2', 'user3')
print(r.pfcount('unique_visitors'))  # 3

5.2.2 发布订阅

发布者:

import redis

r = redis.Redis(host='localhost', port=6379, db=0)

# Publish messages to the "news" channel. publish() returns the number of
# subscribers that received the message. Pub/Sub does not store messages, so
# anything published while no subscriber is listening is lost.
r.publish('news', 'Breaking news: Redis 7.0 released!')
r.publish('news', 'Another news item')

订阅者:

import redis

r = redis.Redis(host='localhost', port=6379, db=0)
pubsub = r.pubsub()

# Subscribe to the "news" channel
pubsub.subscribe('news')

# Block and iterate over incoming events. listen() also yields control events
# such as the 'subscribe' confirmation, which the type check skips.
# Payloads arrive as bytes, hence .decode().
for message in pubsub.listen():
    if message['type'] == 'message':
        print(f"Received: {message['data'].decode()}")

5.2.3 Lua脚本

执行Lua脚本:

import redis

r = redis.Redis(host='localhost', port=6379, db=0)

# Lua script: atomically add ARGV[1] to the counter at KEYS[1] and return the
# new value. Redis runs a script without interleaving other commands, so the
# read-modify-write cannot race. (For this exact operation the built-in
# INCRBY command would suffice; the script illustrates the pattern.)
script = """
    local key = KEYS[1]
    local increment = tonumber(ARGV[1])
    local current = tonumber(redis.call('get', key) or '0')
    current = current + increment
    redis.call('set', key, current)
    return current
"""

# register_script() returns a callable that runs the script with EVALSHA,
# sending only its SHA1 digest, and loads the script body automatically if
# the server has not cached it yet. Plain eval() resends the whole body on
# every call.
increment_counter = r.register_script(script)
result = increment_counter(keys=['counter'], args=[5])
print(f"Counter value: {result}")

5.2.4 缓存策略

缓存模式示例:

import redis
import time

r = redis.Redis(host='localhost', port=6379, db=0)

def get_data_from_db(key):
    """Simulate a slow database lookup for ``key``."""
    time.sleep(1)  # simulated database latency
    return f"data for {key}"

def get_cached_data(key):
    """Cache-aside read: return ``key`` from Redis, loading it from the DB on a miss.

    A miss is detected with ``is None`` rather than truthiness. That way a
    cached empty value (b"") counts as a hit instead of triggering a reload.
    """
    # Try the cache first
    cached = r.get(key)
    if cached is not None:
        print("Cache hit!")
        return cached.decode()

    # Cache miss: load from the database
    print("Cache miss!")
    data = get_data_from_db(key)

    # Store with a 5-minute TTL so stale entries eventually expire
    r.setex(key, 300, data)
    return data

# First call (cache miss)
print(get_cached_data('user:1'))

# Second call (cache hit)
print(get_cached_data('user:1'))

案例2:Redis缓存实现

import redis
from functools import wraps
from datetime import timedelta

class RedisCache:
    """Function-result cache backed by Redis.

    Values are stored as JSON, so a cache hit returns the same type the
    wrapped function returned. Storing ``str(result)`` would turn e.g. ``9``
    into ``"9"`` on every hit. Note that the JSON round trip turns tuples into
    lists and non-string dict keys into strings.
    """

    def __init__(self, host='localhost', port=6379, db=0):
        self.client = redis.Redis(host=host, port=port, db=db)

    @staticmethod
    def build_key(key_prefix, *args, **kwargs):
        """Return the Redis key that ``cache`` uses for a call with ``args``/``kwargs``.

        Handy for ``invalidate``: ``build_key("user_profile", 1)`` returns
        ``"user_profile:(1,):{}"``.
        """
        return f"{key_prefix}:{args}:{kwargs}"

    def cache(self, key_prefix, expire=timedelta(minutes=5)):
        """Decorator that caches a function's return value in Redis.

        Args:
            key_prefix: Key prefix; the rest of the key is the ``repr`` of the
                positional and keyword arguments (see ``build_key``).
            expire: Time-to-live of each entry (at least one second is used).

        Results that cannot be encoded as JSON are returned but not cached.
        """
        import json

        # SETEX rejects a TTL of 0, so sub-second expiries are rounded up to 1s.
        ttl_seconds = max(1, int(expire.total_seconds()))

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                key = self.build_key(key_prefix, *args, **kwargs)

                # `is not None`: falsy cached values (e.g. "" or 0) are still hits.
                cached = self.client.get(key)
                if cached is not None:
                    try:
                        return json.loads(cached)
                    except ValueError:
                        pass  # not written by this decorator's JSON format; recompute and overwrite

                result = func(*args, **kwargs)

                try:
                    payload = json.dumps(result)
                except TypeError:
                    # Not JSON-serializable: skip caching rather than fail the call.
                    return result
                self.client.setex(key, ttl_seconds, payload)
                return result
            return wrapper
        return decorator

    def invalidate(self, key):
        """Delete a single cache entry by its full key."""
        self.client.delete(key)

    def flush_pattern(self, pattern, batch_size=500):
        """Delete every key matching the glob ``pattern`` and return how many were deleted.

        Uses incremental SCAN instead of KEYS. KEYS walks the whole keyspace in
        one blocking call and can stall a busy server. Keys are deleted in
        batches of ``batch_size``.
        """
        deleted = 0
        batch = []
        for key in self.client.scan_iter(match=pattern, count=batch_size):
            batch.append(key)
            if len(batch) >= batch_size:
                deleted += self.client.delete(*batch)
                batch = []
        if batch:
            deleted += self.client.delete(*batch)
        return deleted

# Usage example
cache = RedisCache()

@cache.cache("user_profile")
def get_user_profile(user_id):
    """Return a user's profile (simulates a slow database query)."""
    import time
    time.sleep(0.5)  # simulated latency
    return f"Profile for user {user_id}"

# First call: cache miss, the function body runs
print(get_user_profile(1))  # cache miss

# Second call: cache hit
print(get_user_profile(1))  # served from Redis

# Invalidate the entry. The key has the form "<prefix>:<repr(args)>:<repr(kwargs)>",
# so the call get_user_profile(1) is stored under "user_profile:(1,):{}".
# A keyword call such as get_user_profile(user_id=1) would use a different key.
cache.invalidate("user_profile:(1,):{}")

5.3 向量数据库Milvus

5.3.1 概念与架构

Milvus核心概念:

  • Collection(集合):类似于关系数据库中的表
  • Partition(分区):集合的逻辑划分,用于数据管理
  • Entity(实体):数据记录,包含向量和标量字段
  • Index(索引):加速向量相似度搜索

安装与连接:

# Install the pymilvus client library
pip install pymilvus

5.3.2 Collection管理

创建Collection:

from pymilvus import MilvusClient, DataType

# Connect to a Milvus server
client = MilvusClient("http://localhost:19530")

# Define the schema. auto_id=False: primary keys are supplied on insert.
# enable_dynamic_field=True: entities may carry extra, undeclared fields.
schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=True
)

# Add fields
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=384)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=500)

# Create the collection. When a schema is passed, MilvusClient ignores a
# metric_type argument; it only applies to the quick-setup form that takes
# `dimension`. The distance metric is chosen when the vector index is built.
client.create_collection(
    collection_name="documents",
    schema=schema
)

5.3.3 Index类型

常用索引类型:

  • FLAT:暴力搜索,结果精确(精度最高),速度最慢
  • IVF_FLAT:基于聚类的倒排文件(IVF)索引,搜索时只扫描距查询最近的若干个簇(nprobe),平衡精度和速度
  • IVF_SQ8:在IVF_FLAT基础上对向量做8位标量量化,内存约为原来的1/4,精度略有下降
  • HNSW:基于多层图的索引,查询速度快、召回率高,但内存占用较大
# Build an HNSW index on the vector field. MilvusClient.create_index takes an
# IndexParams object built with prepare_index_params(). The distance metric
# (L2 = Euclidean distance) is declared here, on the index.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="embedding",
    index_type="HNSW",
    metric_type="L2",
    params={
        "M": 16,               # max number of graph neighbours per node
        "efConstruction": 200  # candidate list size while building the graph
    }
)
client.create_index(collection_name="documents", index_params=index_params)

# A collection must be loaded into memory before it can be searched.
client.load_collection(collection_name="documents")

5.3.4 pymilvus API

插入和搜索向量:

import numpy as np

# Insert entities. Each dict supplies the primary key (auto_id=False),
# a 384-dim vector and the text field.
data = [
    {"id": 1, "embedding": np.random.rand(384).tolist(), "text": "Document 1"},
    {"id": 2, "embedding": np.random.rand(384).tolist(), "text": "Document 2"},
    {"id": 3, "embedding": np.random.rand(384).tolist(), "text": "Document 3"},
]

client.insert(collection_name="documents", data=data)

# Vector search. Index-specific search parameters go under "params".
# For HNSW, ef is the candidate list size at query time and must be >= limit.
query_vector = np.random.rand(384).tolist()
results = client.search(
    collection_name="documents",
    data=[query_vector],
    limit=2,
    search_params={"params": {"ef": 50}},
    output_fields=["text"]  # return the text alongside id and distance
)

print(results)

案例3:Milvus向量搜索完整示例

from pymilvus import MilvusClient, DataType
import numpy as np

class MilvusVectorDB:
    """Thin wrapper around MilvusClient for a single embeddings collection.

    The collection has an auto-generated INT64 primary key ``id``, a float
    vector field ``embedding`` and a JSON field ``metadata``.
    """

    def __init__(self, uri="http://localhost:19530"):
        self.client = MilvusClient(uri)
        self.collection_name = "embeddings"

    def create_collection(self, dim=384, metric_type="COSINE"):
        """(Re)create the collection with an HNSW index and load it for search.

        Any existing collection with the same name is dropped first.

        Args:
            dim: Dimension of the embedding vectors.
            metric_type: Distance metric of the index, e.g. "COSINE", "L2" or "IP".
        """
        if self.client.has_collection(self.collection_name):
            self.client.drop_collection(self.collection_name)

        schema = MilvusClient.create_schema(
            auto_id=True,
            enable_dynamic_field=True
        )
        schema.add_field("id", DataType.INT64, is_primary=True)
        schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
        schema.add_field("metadata", DataType.JSON)

        # The metric belongs to the index. Passing index_params to
        # create_collection builds the index and loads the collection, which
        # search() requires.
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="HNSW",
            metric_type=metric_type,
            params={"M": 16, "efConstruction": 200},
        )

        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
        )

    def insert_vectors(self, vectors, metadatas=None):
        """Insert vectors, each with an optional metadata dict.

        Args:
            vectors: Iterable of 1-D vectors (NumPy arrays or float sequences).
            metadatas: Optional list of JSON-serializable dicts, one per vector.
                When omitted, ``{}`` is stored, because ``metadata`` is a
                declared schema field and cannot be left out of an insert.

        Returns:
            The MilvusClient insert result.

        Raises:
            ValueError: If ``metadatas`` is given with a length different from ``vectors``.
        """
        embeddings = [np.asarray(vec, dtype=float).tolist() for vec in vectors]
        if metadatas is None:
            metadatas = [{} for _ in embeddings]
        elif len(metadatas) != len(embeddings):
            raise ValueError(
                f"got {len(embeddings)} vectors but {len(metadatas)} metadata entries"
            )

        data = [
            {"embedding": embedding, "metadata": metadata}
            for embedding, metadata in zip(embeddings, metadatas)
        ]
        return self.client.insert(collection_name=self.collection_name, data=data)

    def search_vectors(self, query_vector, top_k=10, filter=None):
        """Return the ``top_k`` entities closest to ``query_vector``.

        Args:
            query_vector: 1-D vector (NumPy array or float sequence).
            top_k: Number of hits to return.
            filter: Optional Milvus boolean filter expression, e.g. ``'id > 10'``.

        Returns:
            List of hit dicts with ``id``, ``distance`` and ``entity`` (holding
            ``metadata``). For COSINE, a larger ``distance`` means more similar.
        """
        # HNSW requires ef >= top_k; search params are nested under "params".
        search_params = {"params": {"ef": max(50, top_k)}}

        results = self.client.search(
            collection_name=self.collection_name,
            data=[np.asarray(query_vector, dtype=float).tolist()],
            limit=top_k,
            search_params=search_params,
            filter=filter or "",
            output_fields=["metadata"],
        )

        return results[0]

    def get_entity_by_id(self, entity_id):
        """Fetch the entity with primary key ``entity_id`` (a list, empty if not found)."""
        return self.client.get(
            collection_name=self.collection_name,
            ids=[entity_id]
        )

# Usage example
db = MilvusVectorDB()
db.create_collection(dim=384)

# Insert 100 random example vectors, each with a metadata dict
vectors = np.random.rand(100, 384)
metadatas = [{"text": f"Document {i}"} for i in range(100)]
db.insert_vectors(vectors, metadatas)

# Search for the 5 most similar vectors
query_vec = np.random.rand(384)
results = db.search_vectors(query_vec, top_k=5)

# NOTE: the collection is meant to use the COSINE metric. For COSINE, Milvus
# reports a similarity in "distance" (higher = closer), not a true distance.
for res in results:
    print(f"ID: {res['id']}, Distance: {res['distance']}")

5.4 Elasticsearch

5.4.1 倒排索引

倒排索引原理:

文档1: "Python is a programming language"
文档2: "Elasticsearch is a search engine"
文档3: "Python Elasticsearch integration"

倒排索引(词项统一小写,指向包含该词项的文档;为简洁起见省略了 is、a 等常见虚词):
"python"        -> [文档1, 文档3]
"programming"   -> [文档1]
"language"      -> [文档1]
"elasticsearch" -> [文档2, 文档3]
"search"        -> [文档2]
"engine"        -> [文档2]
"integration"   -> [文档3]

创建索引:

from elasticsearch import Elasticsearch

# Connect to Elasticsearch
es = Elasticsearch("http://localhost:9200")

# Index definition. "text" fields are analyzed for full-text search;
# "keyword" fields are stored as-is for exact matching, sorting and aggregations.
mapping = {
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "content": {"type": "text"},
            "author": {"type": "keyword"},
            "published_date": {"type": "date"}
        }
    }
}

# Fails if the index already exists; check es.indices.exists(index=...) first
# when re-running. elasticsearch-py 8.x prefers mappings=... over the
# deprecated body= parameter.
es.indices.create(index="articles", body=mapping)

5.4.2 DSL查询

各种查询类型:

# Full-text search: "content" is analyzed, and by default a document matches
# if it contains any of the query terms (operator OR). Only this first query
# is actually sent; the ones below just show other DSL shapes and are passed
# to es.search the same way.
query = {
    "query": {
        "match": {
            "content": "data storage"
        }
    }
}
result = es.search(index="articles", body=query)

# Phrase search: the terms must appear next to each other, in this order
query = {
    "query": {
        "match_phrase": {
            "content": "machine learning"
        }
    }
}

# Multi-field search: the same match query against several fields
query = {
    "query": {
        "multi_match": {
            "query": "python",
            "fields": ["title", "content"]
        }
    }
}

# Bool query clauses:
#   - "must": required, and contributes to the score;
#   - "filter": required but unscored (and cacheable);
#   - "must_not": excludes matching documents.
# The term query on the keyword field "author" is an exact, case-sensitive
# comparison, so "john" would not match "John".
query = {
    "query": {
        "bool": {
            "must": [{"match": {"content": "data"}}],
            "filter": [{"term": {"author": "John"}}],
            "must_not": [{"match": {"content": "deprecated"}}]
        }
    }
}

5.4.3 中文分词

使用中文分词器:

# Index with a custom Chinese analyzer. The "ik_max_word" tokenizer comes from
# the IK Analysis plugin, which must be installed on every node; without it,
# index creation fails.
# ik_max_word splits text into as many words as possible, which suits
# indexing. The plugin's ik_smart tokenizer splits more coarsely and is often
# used as the search_analyzer.
mapping = {
    "settings": {
        "analysis": {
            "analyzer": {
                "chinese_analyzer": {
                    "tokenizer": "ik_max_word",
                    "filter": ["lowercase"]
                }
            }
        }
    },
    "mappings": {
        "properties": {
            "title": {"type": "text", "analyzer": "chinese_analyzer"},
            "content": {"type": "text", "analyzer": "chinese_analyzer"}
        }
    }
}

es.indices.create(index="chinese_docs", body=mapping)

5.4.4 聚合操作

聚合查询示例:

# Aggregation requests only need the buckets, so "size": 0 skips returning hits.

# Group by author. In the "articles" mapping, "author" is itself a keyword
# field, so it is aggregated directly. An "author.keyword" sub-field only
# exists when a text field was dynamically mapped; here it would match no
# field and return empty buckets.
query = {
    "size": 0,
    "aggs": {
        "authors": {
            "terms": {
                "field": "author",
                "size": 10  # top 10 authors by document count
            }
        }
    }
}

# Date histogram: number of documents per calendar month of published_date
query = {
    "size": 0,
    "aggs": {
        "publication_over_time": {
            "date_histogram": {
                "field": "published_date",
                "calendar_interval": "month"
            }
        }
    }
}

# Average of a numeric field. NOTE: "score" is not part of the mapping shown
# above; on an unmapped field the avg aggregation returns null.
query = {
    "size": 0,
    "aggs": {
        "avg_score": {
            "avg": {
                "field": "score"
            }
        }
    }
}

案例4:Elasticsearch全文搜索

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

class ESClient:
    """Convenience wrapper around an Elasticsearch client: create indices, bulk index, search."""

    def __init__(self, host="localhost", port=9200):
        self.client = Elasticsearch(f"http://{host}:{port}")

    def create_index(self, index_name, mapping=None):
        """Create ``index_name`` (with the optional ``mapping`` body) unless it already exists."""
        if self.client.indices.exists(index=index_name):
            return
        if mapping:
            self.client.indices.create(index=index_name, body=mapping)
        else:
            self.client.indices.create(index=index_name)

    def index_documents(self, index_name, documents, refresh=None):
        """Bulk-index ``documents`` (dicts) into ``index_name``.

        A document's ``"id"`` value, when present, becomes its ``_id``, so
        re-indexing overwrites instead of duplicating. Documents without one
        get an Elasticsearch-generated id.

        Args:
            index_name: Target index.
            documents: Iterable of JSON-serializable dicts.
            refresh: Optional bulk ``refresh`` option (e.g. ``"wait_for"``) to
                make the documents searchable as soon as this call returns.

        Returns:
            The ``(success_count, errors)`` result of ``helpers.bulk``.
        """
        actions = []
        for doc in documents:
            action = {"_index": index_name, "_source": doc}
            if doc.get("id") is not None:
                action["_id"] = doc["id"]
            actions.append(action)
        extra = {} if refresh is None else {"refresh": refresh}
        return bulk(self.client, actions, **extra)

    def search(self, index_name, query, size=10):
        """Run the query-DSL body ``query`` against ``index_name`` and return the raw response."""
        return self.client.search(index=index_name, body=query, size=size)

    def highlight_search(self, index_name, query_text, fields=("title", "content"), size=10):
        """Run a multi_match search over ``fields`` and request highlighted fragments.

        Args:
            index_name: Index to search.
            query_text: Free-text query.
            fields: Fields to match and highlight. The default is a tuple so
                no mutable list is shared between calls.
            size: Maximum number of hits.

        Returns:
            The raw response. Each hit with a match carries a ``highlight``
            dict mapping field name to a list of fragments.
        """
        fields = list(fields)
        query = {
            "query": {
                "multi_match": {
                    "query": query_text,
                    "fields": fields
                }
            },
            "highlight": {
                "fields": {field: {} for field in fields}
            }
        }
        return self.client.search(index=index_name, body=query, size=size)

# Usage example
es = ESClient()

# Create the index
mapping = {
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "content": {"type": "text"},
            "author": {"type": "keyword"}
        }
    }
}
es.create_index("documents", mapping)

# Index the documents
documents = [
    {"id": "1", "title": "Python基础", "content": "Python是一种高级编程语言", "author": "John"},
    {"id": "2", "title": "Elasticsearch指南", "content": "ES是一个分布式搜索引擎", "author": "Alice"},
    {"id": "3", "title": "机器学习入门", "content": "机器学习是人工智能的一个分支", "author": "Bob"}
]
es.index_documents("documents", documents)
# Newly indexed documents only become searchable after a refresh (by default
# about once per second), so refresh explicitly before searching.
es.client.indices.refresh(index="documents")

# Search. With the default standard analyzer, Chinese text is split into
# single characters, so "编程语言" matches documents sharing those characters.
results = es.highlight_search("documents", "编程语言")
for hit in results["hits"]["hits"]:
    print(f"Title: {hit['_source']['title']}")
    print(f"Highlight: {hit.get('highlight', {})}")

5.5 Embedding模型

5.5.1 文本嵌入原理

文本嵌入流程:

文本输入 → Tokenizer分词 → Transformer编码(得到每个token的向量) → 池化(Pooling,如取平均) → 向量输出

例如:
输入: "Hello world"
输出: [0.123, -0.456, 0.789, ...] (维度通常为768或384)

常用Embedding模型:

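  • Sentence-BERT:专注于句子级嵌入
  • OpenAI Embeddings:如 text-embedding-ada-002,以及更新的 text-embedding-3-small / text-embedding-3-large
  • BGE:北京智源人工智能研究院(BAAI)开源的嵌入模型,中文效果好,也有英文及多语言版本
  • MiniLM:轻量级模型(如 all-MiniLM-L6-v2,主要面向英文)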

5.5.2 维度选择

维度选择考虑因素:

因素 说明
模型能力 更高维度通常表示更强的表达能力
存储成本 高维度向量占用更多存储空间
搜索速度 维度越高,相似度计算越慢
应用场景 语义搜索通常用384-768维

5.5.3 批量编码

使用Sentence-BERT批量编码:

from sentence_transformers import SentenceTransformer
import numpy as np

# Load the model (downloaded from the Hugging Face Hub on first use).
# NOTE(review): per its model card, all-MiniLM-L6-v2 is trained on English
# data, so embeddings of the Chinese sentences below are likely of limited
# quality. A multilingual model such as paraphrase-multilingual-MiniLM-L12-v2
# (also 384-dim) or a Chinese BGE model is probably a better fit — confirm.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Sentences to encode in one call
sentences = [
    "Python是一种高级编程语言",
    "机器学习是人工智能的分支",
    "深度学习是机器学习的子领域",
    "神经网络由多个神经元组成"
]

# encode() on a list returns a NumPy array with one row per sentence
embeddings = model.encode(sentences)

print(f"向量形状: {embeddings.shape}")  # (4, 384)
print(f"第一个向量: {embeddings[0][:5]}")  # first 5 dimensions

案例5:Embedding服务封装

from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List

class EmbeddingService:
    """Sentence-embedding helper built on a SentenceTransformer model."""

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name)
        self.dim = self.model.get_sentence_embedding_dimension()

    def encode(self, texts: List[str], batch_size=32) -> np.ndarray:
        """Encode ``texts`` into an array of shape ``(len(texts), self.dim)``.

        Mini-batching is delegated to ``SentenceTransformer.encode(batch_size=...)``,
        which already splits its input into batches. An empty input returns an
        empty ``(0, self.dim)`` array instead of failing inside ``np.vstack([])``.
        """
        if not texts:
            return np.empty((0, self.dim or 0), dtype=np.float32)
        return np.asarray(self.model.encode(list(texts), batch_size=batch_size))

    def encode_single(self, text: str) -> np.ndarray:
        """Encode a single text into a 1-D vector."""
        return self.model.encode(text)

    def compute_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 when either has zero norm (avoids NaN)."""
        norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / norm)

    def rank_by_similarity(self, query_vec: np.ndarray, doc_vecs: np.ndarray) -> List[int]:
        """Return indices of ``doc_vecs`` ordered from most to least similar to ``query_vec``.

        All cosine similarities come from one matrix-vector product. Zero-norm
        vectors score 0.0, and ties keep their original order.
        """
        docs = np.asarray(doc_vecs, dtype=float)
        if docs.shape[0] == 0:
            return []
        query = np.asarray(query_vec, dtype=float)
        norms = np.linalg.norm(docs, axis=1) * np.linalg.norm(query)
        dots = docs @ query
        similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms != 0)
        return np.argsort(-similarities, kind="stable").tolist()

# Usage example
service = EmbeddingService()

# Encode the documents into a 2-D array, one row per document
documents = [
    "Python编程入门指南",
    "机器学习算法详解",
    "深度学习实战",
    "数据结构与算法"
]
doc_vectors = service.encode(documents)

# Encode the query
query = "Python机器学习"
query_vec = service.encode_single(query)

# Cosine similarity between the query and each document
similarities = [service.compute_similarity(query_vec, dv) for dv in doc_vectors]
for doc, sim in zip(documents, similarities):
    print(f"{doc}: {sim:.4f}")

5.6 相似度算法

5.6.1 余弦相似度

余弦相似度公式:

sim(A, B) = (A · B) / (||A|| × ||B||)

实现:

import numpy as np

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine similarity (A · B) / (||A|| × ||B||) of two vectors.

    Returns 0.0 when either vector has zero norm. The formula would otherwise
    divide by zero, producing NaN and a RuntimeWarning. This matches the
    convention of SimilarityCalculator.cosine later in this chapter.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / norm_product)

# Example
vec1 = np.array([1, 2, 3])
vec2 = np.array([4, 5, 6])
print(cosine_similarity(vec1, vec2))  # 0.9746

5.6.2 欧氏距离

欧氏距离公式:

distance(A, B) = sqrt(sum((A_i - B_i)^2))

实现:

def euclidean_distance(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the straight-line (L2) distance between two vectors."""
    gap = vec1 - vec2
    return np.linalg.norm(gap)

# Example: every coordinate differs by 3, so the distance is sqrt(3 * 3**2) = sqrt(27)
vec1, vec2 = np.array([1, 2, 3]), np.array([4, 5, 6])
print(euclidean_distance(vec1, vec2))  # 5.196

5.6.3 内积

内积相似度:

def dot_product(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the inner product sum(vec1[i] * vec2[i]).

    It is not normalized, so it grows with the vectors' magnitudes; for
    unit-length vectors it equals the cosine similarity.
    """
    inner = np.dot(vec1, vec2)
    return inner

# Example: 1*4 + 2*5 + 3*6 = 32
vec1, vec2 = np.array([1, 2, 3]), np.array([4, 5, 6])
print(dot_product(vec1, vec2))  # 32

5.6.4 HNSW索引原理

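HNSW(Hierarchical Navigable Small World,分层可导航小世界图):

层级结构示意(每个节点随机获得一个最高层,并出现在该层及以下所有层中;竖线连接的是同一个节点,越往上节点越稀疏):
    Level 2:   A ------------------ E
               |                    |
    Level 1:   A ------- C -------- E ---- H
               |         |          |      |
    Level 0:   A -- B -- C -- D --- E -- F -- G -- H -- I

搜索过程:
1. 从顶层的入口点开始
2. 在当前层贪心地移动到离查询向量更近的邻居,直到无法再靠近
3. 以该点作为下一层的入口,逐层向下重复上述过程
4. 在第0层用大小为 ef 的候选集做最佳优先搜索,返回近似的 top-k 结果(这是近似搜索,并非精确搜索)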

案例6:多种相似度算法对比

import numpy as np

class SimilarityCalculator:
    """Vector similarity measures; every method returns a score where larger means more alike."""

    @staticmethod
    def cosine(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Cosine similarity in [-1, 1]; 0.0 when either vector has zero norm."""
        norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm == 0:
            return 0.0
        return np.dot(vec1, vec2) / norm

    @staticmethod
    def euclidean(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Euclidean distance d mapped to the similarity 1 / (1 + d), in (0, 1]."""
        dist = np.linalg.norm(vec1 - vec2)
        return 1 / (1 + dist)  # convert distance to similarity

    @staticmethod
    def dot(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Normalized inner product, i.e. the inner product of the two unit vectors.

        This is mathematically identical to cosine similarity. It therefore
        delegates to ``cosine`` and shares its zero-norm guard, whereas the
        bare formula returns NaN for a zero vector. For the raw, unnormalized
        inner product use ``np.dot`` directly.
        """
        return SimilarityCalculator.cosine(vec1, vec2)

    @staticmethod
    def manhattan(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Manhattan (L1) distance d mapped to the similarity 1 / (1 + d), in (0, 1]."""
        dist = np.sum(np.abs(vec1 - vec2))
        return 1 / (1 + dist)

    @staticmethod
    def jaccard(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Jaccard similarity |A ∩ B| / |A ∪ B| of binary vectors (non-zero = set member).

        Returns 0.0 when both vectors are all zeros.
        """
        intersection = np.sum(np.logical_and(vec1, vec2))
        union = np.sum(np.logical_or(vec1, vec2))
        if union == 0:
            return 0.0
        return intersection / union

# Usage example (all methods are static, so the instance is optional)
calculator = SimilarityCalculator()

vec1 = np.array([1, 2, 3, 4, 5])
vec2 = np.array([2, 3, 4, 5, 6])
vec3 = np.array([10, 20, 30, 40, 50])  # vec1 scaled by 10

print(f"余弦相似度 (v1, v2): {calculator.cosine(vec1, vec2):.4f}")  # 0.9949
print(f"欧氏相似度 (v1, v2): {calculator.euclidean(vec1, vec2):.4f}")  # 0.3090 (distance sqrt(5))
# Cosine ignores magnitude: vec3 points in exactly the same direction as vec1
print(f"余弦相似度 (v1, v3): {calculator.cosine(vec1, vec3):.4f}")  # 1.0000

5.7 混合检索

5.7.1 向量+关键词检索

混合检索架构:

                    查询
                      │
        ┌─────────────┼─────────────┐
        ▼             ▼             ▼
    向量检索      关键词检索      语义检索
        │             │             │
        ▼             ▼             ▼
    Milvus       Elasticsearch   Embedding
        │             │             │
        └─────────────┼─────────────┘
                      ▼
                结果融合
                      │
                      ▼
                   最终结果

5.7.2 RRF(Reciprocal Rank Fusion)

RRF公式:

score(d) = sum(1 / (k + rank_i(d)))

其中:
- k: 常数(通常取60)
- rank_i(d): 结果d在第i个检索器结果中的排名(从1开始;若d未出现在该检索器的结果中,则该项不计入总分)

实现:

def reciprocal_rank_fusion(rankings: list, k: int = 60) -> dict:
    """Fuse several ranked lists with Reciprocal Rank Fusion (RRF).

    Each document scores sum(1 / (k + rank)) over the lists it appears in,
    with rank starting at 1. A list that does not contain a document adds
    nothing to that document's score.

    Args:
        rankings: List of ranked lists of document ids, best first.
        k: Smoothing constant. Larger values shrink the gap between top and
            lower ranks; 60 is the value used in the original RRF paper.

    Returns:
        Dict mapping document id to fused score, ordered from highest to
        lowest score. Equal scores keep first-seen order, because sorted()
        is stable.
    """
    scores = {}
    for rank_list in rankings:
        for rank, doc_id in enumerate(rank_list, start=1):
            scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank)

    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

# Example
vector_ranking = ["doc1", "doc3", "doc2", "doc4"]
keyword_ranking = ["doc2", "doc1", "doc5", "doc3"]

fused = reciprocal_rank_fusion([vector_ranking, keyword_ranking])
print({doc_id: round(score, 4) for doc_id, score in fused.items()})
# {'doc1': 0.0325, 'doc2': 0.0323, 'doc3': 0.0318, 'doc5': 0.0159, 'doc4': 0.0156}

案例7:混合检索实现

from typing import List, Dict, Tuple
import numpy as np

class HybridRetriever:
    """Hybrid retrieval: fuse vector-search and keyword-search rankings with RRF.

    NOTE: RRF merges results by document id, so both back ends must share the
    same id space. One way is to store the Elasticsearch ``_id`` in the vector
    entity and return it. Otherwise a document never collects score from both
    lists.
    """

    def __init__(self, vector_db, keyword_db, embedding_service,
                 keyword_index: str = "documents",
                 keyword_fields: Tuple[str, ...] = ("title", "content")):
        """
        Args:
            vector_db: Object with ``search_vectors(query_vec, top_k=...)``
                returning hits with an ``"id"`` key (e.g. ``MilvusVectorDB``).
            keyword_db: Object with ``search(index_name, query_body, size=...)``
                returning an Elasticsearch-style response (e.g. ``ESClient``).
            embedding_service: Object with ``encode_single(text)`` (e.g. ``EmbeddingService``).
            keyword_index: Index searched by ``keyword_search``.
            keyword_fields: Fields matched by the keyword ``multi_match`` query.
        """
        self.vector_db = vector_db
        self.keyword_db = keyword_db
        self.embedding_service = embedding_service
        self.keyword_index = keyword_index
        self.keyword_fields = keyword_fields

    def vector_search(self, query: str, top_k: int = 10) -> List[str]:
        """Embed ``query`` and return the ids of the nearest vectors as strings, best first."""
        query_vec = self.embedding_service.encode_single(query)
        results = self.vector_db.search_vectors(query_vec, top_k=top_k)
        return [str(res["id"]) for res in results]

    def keyword_search(self, query: str, top_k: int = 10) -> List[str]:
        """Full-text search for ``query``; return Elasticsearch ``_id`` values, best first.

        ``keyword_db.search`` takes ``(index_name, query_body, size)``, so the
        index name and a ``multi_match`` body are built here from the
        configured ``keyword_index`` and ``keyword_fields``.
        """
        body = {
            "query": {
                "multi_match": {
                    "query": query,
                    "fields": list(self.keyword_fields)
                }
            }
        }
        results = self.keyword_db.search(self.keyword_index, body, size=top_k)
        return [hit["_id"] for hit in results["hits"]["hits"]]

    def reciprocal_rank_fusion(self, rankings: List[List[str]], k: int = 60) -> Dict[str, float]:
        """RRF: score each id by sum(1 / (k + rank)) with rank starting at 1; highest first."""
        scores: Dict[str, float] = {}
        for rank_list in rankings:
            for rank, doc_id in enumerate(rank_list, start=1):
                scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank)
        return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

    def hybrid_search(self, query: str, top_k: int = 10, rrf_k: int = 60,
                      candidate_k: int = None) -> List[Tuple[str, float]]:
        """Run both retrievers, fuse their rankings with RRF and return the top ``top_k``.

        Args:
            query: Query text.
            top_k: Number of fused results to return.
            rrf_k: RRF smoothing constant.
            candidate_k: Results fetched from each retriever before fusion;
                defaults to ``top_k``. Fetching more lets documents ranked just
                outside one retriever's top ``top_k`` still collect score.

        Returns:
            List of ``(doc_id, fused_score)`` tuples, best first.
        """
        n_candidates = candidate_k or top_k
        vector_results = self.vector_search(query, n_candidates)
        keyword_results = self.keyword_search(query, n_candidates)

        fused = self.reciprocal_rank_fusion([vector_results, keyword_results], k=rrf_k)

        return list(fused.items())[:top_k]

# Usage example (pseudo-code)
# vector_db = MilvusVectorDB()
# keyword_db = ESClient()
# embedding_service = EmbeddingService()

# retriever = HybridRetriever(vector_db, keyword_db, embedding_service, keyword_index="documents")
# results = retriever.hybrid_search("Python机器学习", top_k=5)
# for doc_id, score in results:
#     print(f"Document: {doc_id}, Score: {score:.4f}")

5.8 数据管道设计

5.8.1 ETL流程

ETL架构:

数据源 → 抽取(Extract) → 转换(Transform) → 加载(Load) → 数据仓库

详细流程:
1. Extract: 从数据库、API、文件等抽取原始数据
2. Transform: 清洗、转换、标准化数据
3. Load: 加载到目标存储系统

ETL示例:

import pandas as pd
import psycopg2
from sqlalchemy import create_engine

class ETLPipeline:
    """Minimal extract -> transform -> load pipeline between two SQLAlchemy databases."""

    def __init__(self, source_conn_str, target_conn_str):
        # One engine (connection pool) for each side of the pipeline
        self.source_engine = create_engine(source_conn_str)
        self.target_engine = create_engine(target_conn_str)

    def extract(self, query: str) -> pd.DataFrame:
        """Run ``query`` against the source database and return the rows as a DataFrame."""
        frame = pd.read_sql(query, self.source_engine)
        return frame

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``df``.

        Drops every row that contains any null. Reformats ``created_at`` as a
        'YYYY-MM-DD' string, discarding the time of day. Appends a
        ``processed_at`` column holding the current local time.
        """
        return df.dropna().assign(
            created_at=lambda frame: pd.to_datetime(frame['created_at']).dt.strftime('%Y-%m-%d'),
            processed_at=pd.Timestamp.now(),
        )

    def load(self, df: pd.DataFrame, table_name: str):
        """Append ``df`` to ``table_name`` in the target database (the DataFrame index is not written)."""
        df.to_sql(table_name, self.target_engine, if_exists='append', index=False)

    def run(self, query: str, table_name: str):
        """Extract with ``query``, transform, load into ``table_name``, then print the row count."""
        loaded = self.transform(self.extract(query))
        self.load(loaded, table_name)
        print(f"Loaded {len(loaded)} records")

# Usage example. NOTE: placeholder credentials. Both databases must be
# reachable, and the source "users" table must have a created_at column,
# because transform() reformats it.
pipeline = ETLPipeline(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run("SELECT * FROM users", "users_staging")

5.8.2 增量更新

增量更新策略:

class IncrementalETL(ETLPipeline):
    """ETL pipeline that only copies source rows changed since the previous run.

    The watermark is the largest ``updated_at`` already loaded into the target
    table. Comparing the source's ``updated_at`` against the target's
    ``processed_at`` (the ETL's wall-clock time) would skip rows updated
    between the extract query and the transform step.

    NOTE: rows are appended, so a source row updated again is loaded again.
    An upsert is needed if the target must hold one row per id.
    """

    def __init__(self, source_conn_str, target_conn_str):
        super().__init__(source_conn_str, target_conn_str)
        self.last_sync_time = None  # watermark used by the most recent incremental_extract

    def get_last_sync_time(self, table_name: str, column: str = "updated_at"):
        """Return ``MAX(column)`` of target table ``table_name``.

        Returns the value from the database (typically a datetime), or the
        string ``"1970-01-01"`` when the table does not exist yet (first run)
        or holds no rows.
        """
        from sqlalchemy import inspect, text

        if not inspect(self.target_engine).has_table(table_name):
            return "1970-01-01"
        # table_name/column are identifiers supplied by code, not user input.
        # text() is required for raw SQL strings in SQLAlchemy 2.x.
        query = text(f"SELECT MAX({column}) FROM {table_name}")
        with self.target_engine.connect() as conn:
            latest = conn.execute(query).scalar()
        return latest if latest is not None else "1970-01-01"

    def incremental_extract(self, table_name: str, target_table: str = None) -> pd.DataFrame:
        """Extract rows of source ``table_name`` whose ``updated_at`` is after the watermark.

        Args:
            table_name: Source table to read.
            target_table: Target table holding already-loaded rows, from
                which the watermark is read. Defaults to ``table_name``.
        """
        from sqlalchemy import text

        self.last_sync_time = self.get_last_sync_time(target_table or table_name)
        # The watermark is a bound parameter rather than being pasted into the SQL.
        query = text(f"SELECT * FROM {table_name} WHERE updated_at > :since")
        return pd.read_sql(query, self.source_engine, params={"since": self.last_sync_time})

    def run_incremental(self, source_table: str, target_table: str):
        """Extract changed rows from ``source_table``, transform them and append them to ``target_table``."""
        df = self.incremental_extract(source_table, target_table)
        if df.empty:
            print("No new records to process")
            return
        df = self.transform(df)
        self.load(df, target_table)
        print(f"Incrementally loaded {len(df)} records")

# Usage example
pipeline = IncrementalETL(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run_incremental("users", "users_staging")

5.8.3 一致性保障

数据一致性策略:

import contextlib

class TransactionalETL(ETLPipeline):
    """ETL pipeline whose delete-and-reload step is atomic in the target database."""

    def __init__(self, source_conn_str, target_conn_str):
        super().__init__(source_conn_str, target_conn_str)

    @contextlib.contextmanager
    def transaction(self):
        """Yield a target-database connection inside a transaction.

        Commits when the ``with`` body succeeds; on any exception it rolls
        back and re-raises. The connection is always closed. Only statements
        executed on the yielded connection are part of the transaction.
        """
        conn = self.target_engine.connect()
        trans = conn.begin()
        try:
            yield conn
            trans.commit()
            print("Transaction committed")
        except Exception as e:
            trans.rollback()
            print(f"Transaction rolled back: {e}")
            raise
        finally:
            conn.close()

    def run_with_transaction(self, query: str, table_name: str):
        """Extract and transform, then atomically replace rows with the same ids in ``table_name``.

        The DELETE of existing rows and the INSERT of the new ones both run on
        the connection yielded by ``transaction()``, so they commit or roll
        back together. Going through ``self.target_engine`` (as ``load`` does)
        would use separate connections outside the transaction.
        """
        from sqlalchemy import bindparam, inspect, text

        df = self.extract(query)
        df = self.transform(df)

        with self.transaction() as conn:
            # Remove rows that are about to be re-loaded, so reruns don't duplicate them.
            ids = df['id'].tolist()
            if ids and inspect(conn).has_table(table_name):
                # Ids go in as an expanding bound parameter (IN (...)) instead of
                # being pasted into the SQL. An empty list is skipped because
                # "IN ()" is invalid SQL.
                delete_stmt = text(f"DELETE FROM {table_name} WHERE id IN :ids").bindparams(
                    bindparam("ids", expanding=True)
                )
                conn.execute(delete_stmt, {"ids": ids})

            # Same connection, so the insert is part of the transaction.
            df.to_sql(table_name, conn, if_exists='append', index=False)
            print(f"Successfully loaded {len(df)} records")

# Usage example
pipeline = TransactionalETL(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run_with_transaction("SELECT * FROM orders WHERE status = 'completed'", "completed_orders")

案例8:完整数据管道

import pandas as pd
from sqlalchemy import create_engine
from datetime import datetime, timedelta
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DataPipeline:
    """Extract -> validate -> transform -> load pipeline with logging and data-quality stats.

    ``config`` must provide SQLAlchemy connection URLs under ``"source"`` and ``"target"``.
    """

    def __init__(self, config):
        self.config = config
        self.source_engine = create_engine(config["source"])
        self.target_engine = create_engine(config["target"])

    def extract(self, table_name: str, days: int = 7) -> pd.DataFrame:
        """Return rows of ``table_name`` whose ``created_at`` falls within the last ``days`` days."""
        from sqlalchemy import text

        cutoff_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')
        # The cutoff is a bound parameter. table_name is an identifier from code/config.
        query = text(f"SELECT * FROM {table_name} WHERE created_at >= :cutoff")
        logger.info("Extracting data from %s since %s", table_name, cutoff_date)
        return pd.read_sql(query, self.source_engine, params={"cutoff": cutoff_date})

    def validate(self, df: pd.DataFrame) -> bool:
        """Return True if ``df`` has the required columns and unique ids; log the reason otherwise."""
        required_columns = ['id', 'name', 'created_at']
        if not all(col in df.columns for col in required_columns):
            logger.error("Missing required columns")
            return False

        if df['id'].duplicated().any():
            logger.error("Duplicate IDs found")
            return False

        logger.info("Data validation passed")
        return True

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``df`` with derived columns; the input frame is not modified.

        Converts ``created_at`` to datetimes. Adds ``year_month`` (the
        'YYYY-MM' of ``created_at``) and ``processing_time`` (current local time).
        """
        logger.info("Transforming data...")

        # Safety net only: run() calls validate(), which rejects duplicate ids.
        # .copy() yields an independent frame, so the assignments below don't
        # trigger pandas' SettingWithCopyWarning.
        df = df.drop_duplicates(subset=['id']).copy()

        # Type conversion
        df['created_at'] = pd.to_datetime(df['created_at'])

        # Derived fields
        df['year_month'] = df['created_at'].dt.strftime('%Y-%m')
        df['processing_time'] = datetime.now()

        return df

    def load(self, df: pd.DataFrame, table_name: str, mode: str = 'append'):
        """Write ``df`` to ``table_name``; ``mode`` is pandas' ``if_exists`` ('append', 'replace', 'fail')."""
        logger.info("Loading %d records to %s", len(df), table_name)
        df.to_sql(table_name, self.target_engine, if_exists=mode, index=False)

    def monitor(self, df: pd.DataFrame, step: str):
        """Log row count, columns and per-column null counts of ``df`` for pipeline stage ``step``."""
        stats = {
            "step": step,
            "record_count": len(df),
            "columns": list(df.columns),
            "null_counts": df.isnull().sum().to_dict(),
            "timestamp": datetime.now().isoformat()
        }
        logger.info("Monitoring stats: %s", stats)

    def run(self, table_name: str, days: int = 7, target_table: str = None):
        """Run extract -> validate -> transform -> load for ``table_name``.

        Args:
            table_name: Source table to read.
            days: Extraction window passed to ``extract``.
            target_table: Destination table; defaults to ``table_name``.

        Raises:
            Exception: Any stage error is logged with its traceback and re-raised.
        """
        try:
            # Extract
            df = self.extract(table_name, days=days)
            self.monitor(df, "extract")

            # Validate
            if not self.validate(df):
                logger.warning("Pipeline aborted: validation failed for %s", table_name)
                return

            # Transform
            df = self.transform(df)
            self.monitor(df, "transform")

            # Load
            self.load(df, target_table or table_name)

            logger.info("Pipeline completed successfully")
        except Exception as e:
            logger.error("Pipeline failed: %s", e, exc_info=True)
            raise

# Pipeline configuration: SQLAlchemy URLs (placeholder credentials)
config = {
    "source": "postgresql://user:pass@localhost/source_db",
    "target": "postgresql://user:pass@localhost/data_warehouse"
}

# Run the pipeline
pipeline = DataPipeline(config)
pipeline.run("user_activity")

小结

本章介绍了数据存储与检索的核心内容,包括:

  1. PostgreSQL:数据模型设计、索引优化、JSON类型、全文搜索
  2. Redis:数据结构、发布订阅、Lua脚本、缓存策略
  3. Milvus向量数据库:Collection管理、Index类型、pymilvus API
  4. Elasticsearch:倒排索引、DSL查询、中文分词、聚合操作
  5. Embedding模型:文本嵌入原理、维度选择、批量编码
  6. 相似度算法:余弦相似度、欧氏距离、内积、HNSW索引原理
  7. 混合检索:向量+关键词检索、RRF得分融合策略
  8. 数据管道设计:ETL流程、增量更新、一致性保障

掌握这些技术对于构建现代化的AI应用,特别是RAG系统,至关重要。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐