第05章_数据存储与检索
·
第05章 数据存储与检索
5.1 PostgreSQL
5.1.1 数据模型设计
表设计最佳实践:
-- users: one row per account; username and email are each UNIQUE and NOT NULL.
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    username VARCHAR(50) UNIQUE NOT NULL,
    email VARCHAR(100) UNIQUE NOT NULL,
    password_hash VARCHAR(255) NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- Only defaulted at insert time; nothing in this DDL refreshes it on UPDATE.
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- articles: one row per article; author_id points at users(id) and may be NULL.
CREATE TABLE articles (
    id SERIAL PRIMARY KEY,
    title VARCHAR(255) NOT NULL,
    content TEXT,
    author_id INTEGER REFERENCES users(id),
    published BOOLEAN DEFAULT FALSE,
    published_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
索引设计:
-- Single-column index on the author foreign key
CREATE INDEX idx_articles_author_id ON articles(author_id);
-- Composite index on (author_id, published)
CREATE INDEX idx_articles_author_published ON articles(author_id, published);
-- Partial index: only rows with published = TRUE are indexed
CREATE INDEX idx_articles_published ON articles(id) WHERE published = TRUE;
5.1.2 JSON类型
JSONB字段操作:
-- Create a table with a JSONB column
CREATE TABLE products (
    id SERIAL PRIMARY KEY,
    name VARCHAR(100),
    attributes JSONB
);
-- Insert a JSON document (nested "specs" object included)
INSERT INTO products (name, attributes) VALUES (
    'Laptop',
    '{"brand": "Apple", "model": "MacBook Pro", "specs": {"cpu": "M2", "ram": 16}}'
);
-- Read one JSON key; ->> returns the value as text
SELECT attributes->>'brand' AS brand FROM products;
-- Containment query: rows whose attributes contain the given key/value pair
SELECT * FROM products WHERE attributes @> '{"brand": "Apple"}';
Python操作JSONB:
import psycopg2
import json
# Connect to the database.
conn = psycopg2.connect("dbname=example user=postgres")
try:
    # Document to store; "attributes" goes into the JSONB column.
    product = {
        "name": "Laptop",
        "attributes": {
            "brand": "Apple",
            "model": "MacBook Pro",
            "specs": {"cpu": "M2", "ram": 16},
        },
    }
    # `with conn` commits on success and rolls back on error; the cursor
    # context manager closes the cursor. Neither closes the connection.
    with conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO products (name, attributes)
            VALUES (%s, %s)
            """,
            (product["name"], json.dumps(product["attributes"])),
        )
finally:
    # Always release the connection, even if the insert failed.
    conn.close()
5.1.3 全文搜索
使用tsvector和tsquery:
-- Add a tsvector column, populate it, and index it with GIN.
ALTER TABLE articles ADD COLUMN search_vector tsvector;
-- content is nullable; coalesce keeps a NULL content from turning the whole
-- concatenation (and therefore the search vector) into NULL.
UPDATE articles
SET search_vector = to_tsvector('english', coalesce(title, '') || ' ' || coalesce(content, ''));
CREATE INDEX idx_articles_search ON articles USING gin(search_vector);
-- Full-text query: to_tsquery needs explicit operators between terms.
SELECT title FROM articles
WHERE search_vector @@ to_tsquery('english', 'data & storage');
-- Ranked search: plainto_tsquery accepts plain text ('data storage' would be
-- a syntax error in to_tsquery) and ANDs the terms together.
SELECT title, ts_rank(search_vector, query) AS rank
FROM articles, plainto_tsquery('english', 'data storage') AS query
WHERE search_vector @@ query
ORDER BY rank DESC;
案例1:PostgreSQL CRUD操作
import psycopg2
from psycopg2 import sql
from datetime import datetime
class PostgresCRUD:
    """Minimal CRUD helper that owns one psycopg2 connection and one cursor.

    Write methods commit on success. A statement that fails rolls the
    transaction back before the exception propagates, so the connection stays
    usable. The object also works as a context manager that closes the cursor
    and connection on exit.
    """

    def __init__(self, dbname, user, password, host="localhost", port=5432):
        self.conn = psycopg2.connect(
            dbname=dbname, user=user, password=password, host=host, port=port
        )
        self.cur = self.conn.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _execute(self, query, params):
        """Run one statement; roll back and re-raise if it fails."""
        try:
            self.cur.execute(query, params)
        except Exception:
            self.conn.rollback()
            raise

    def create_user(self, username, email, password_hash):
        """Insert a user and return the generated id."""
        self._execute(
            """
            INSERT INTO users (username, email, password_hash)
            VALUES (%s, %s, %s)
            RETURNING id
            """,
            (username, email, password_hash),
        )
        # Read the RETURNING row before committing.
        new_id = self.cur.fetchone()[0]
        self.conn.commit()
        return new_id

    def get_user(self, user_id):
        """Return the users row with this id as a tuple, or None."""
        self._execute("SELECT * FROM users WHERE id = %s", (user_id,))
        return self.cur.fetchone()

    def update_user(self, user_id, **kwargs):
        """Update the given columns (keyword name = column name) of one user.

        Calling it without any column is a no-op; "UPDATE users SET WHERE ..."
        would be invalid SQL.
        """
        if not kwargs:
            return
        # Column names are quoted through sql.Identifier; values are bound.
        set_clause = sql.SQL(", ").join(
            sql.SQL("{} = %s").format(sql.Identifier(column)) for column in kwargs
        )
        query = sql.SQL("UPDATE users SET {} WHERE id = %s").format(set_clause)
        self._execute(query, [*kwargs.values(), user_id])
        self.conn.commit()

    def delete_user(self, user_id):
        """Delete the user with this id."""
        self._execute("DELETE FROM users WHERE id = %s", (user_id,))
        self.conn.commit()

    def search_articles(self, query):
        """Return up to 10 (title, rank) rows matching a free-text query.

        The text goes through plainto_tsquery, so multi-word input such as
        "data storage" is ANDed instead of raising a tsquery syntax error.
        """
        self._execute(
            """
            SELECT title, ts_rank(search_vector, plainto_tsquery('english', %s)) AS rank
            FROM articles
            WHERE search_vector @@ plainto_tsquery('english', %s)
            ORDER BY rank DESC
            LIMIT 10
            """,
            (query, query),
        )
        return self.cur.fetchall()

    def close(self):
        """Close the cursor and the connection."""
        self.cur.close()
        self.conn.close()
# Usage example
crud = PostgresCRUD(dbname="example", user="postgres", password="secret")
try:
    user_id = crud.create_user("john", "john@example.com", "hashed_password")
    print(f"Created user with ID: {user_id}")
finally:
    # Release the cursor and connection even if the insert fails.
    crud.close()
5.2 Redis
5.2.1 数据结构
Redis支持多种数据结构:
import redis
# Client for the Redis server on localhost:6379, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)
# String
r.set('name', 'John')
print(r.get('name')) # b'John'
# Hash: several fields stored under one key
r.hset('user:1000', mapping={
    'username': 'john',
    'email': 'john@example.com',
    'age': '30'
})
print(r.hgetall('user:1000'))
# List
r.lpush('tasks', 'task1', 'task2', 'task3')
print(r.lrange('tasks', 0, -1)) # all elements, read left to right
# Set
r.sadd('tags', 'python', 'redis', 'fastapi')
print(r.smembers('tags')) # all members
# Sorted Set
r.zadd('scores', {'Alice': 95, 'Bob': 88, 'Charlie': 92})
print(r.zrange('scores', 0, -1, withscores=True)) # ordered by score
# HyperLogLog (cardinality estimation)
r.pfadd('unique_visitors', 'user1', 'user2', 'user3')
print(r.pfcount('unique_visitors')) # 3
5.2.2 发布订阅
发布者:
import redis
# Client for the Redis server on localhost:6379, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)
# Publish two messages on the 'news' channel
r.publish('news', 'Breaking news: Redis 7.0 released!')
r.publish('news', 'Another news item')
订阅者:
import redis
# Client for the Redis server on localhost:6379, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)
pubsub = r.pubsub()
# Subscribe to the 'news' channel
pubsub.subscribe('news')
# Listen for messages; this loop has no exit condition.
for message in pubsub.listen():
    # Only entries of type 'message' are printed; other entry types are skipped.
    if message['type'] == 'message':
        print(f"Received: {message['data'].decode()}")
5.2.3 Lua脚本
执行Lua脚本:
import redis
# Client for the Redis server on localhost:6379, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)
# Lua script: add ARGV[1] to the counter stored at KEYS[1] (treated as 0 when
# the key is missing), write it back, and return the new value.
script = """
local key = KEYS[1]
local increment = tonumber(ARGV[1])
local current = tonumber(redis.call('get', key) or '0')
current = current + increment
redis.call('set', key, current)
return current
"""
# Run the script with one key ('counter') and one argument (5)
result = r.eval(script, 1, 'counter', 5)
print(f"Counter value: {result}")
5.2.4 缓存策略
缓存模式示例:
import redis
import time
# Shared client for the Redis server on localhost:6379, database 0.
r = redis.Redis(host='localhost', port=6379, db=0)


def get_data_from_db(key):
    """Simulate a slow database read for ``key``."""
    time.sleep(1)  # stand-in for database latency
    return f"data for {key}"


def get_cached_data(key):
    """Return the value for ``key`` using the cache-aside pattern.

    Looks in Redis first; on a miss, loads the value from the (simulated)
    database and caches it for 5 minutes.
    """
    cached = r.get(key)
    # Compare against None: an empty cached value (b"") is still a hit.
    if cached is not None:
        print("Cache hit!")
        return cached.decode()
    # Cache miss: fall back to the database.
    print("Cache miss!")
    data = get_data_from_db(key)
    # Cache the value with a 5-minute (300 s) expiry.
    r.setex(key, 300, data)
    return data


# First call (cache miss)
print(get_cached_data('user:1'))
# Second call (cache hit)
print(get_cached_data('user:1'))
案例2:Redis缓存实现
import redis
from functools import wraps
from datetime import timedelta
class RedisCache:
    """Function-result cache backed by Redis."""

    def __init__(self, host='localhost', port=6379, db=0):
        self.client = redis.Redis(host=host, port=port, db=db)

    def cache(self, key_prefix, expire=timedelta(minutes=5)):
        """Decorator factory: cache the decorated function's result in Redis.

        The key is "<key_prefix>:<args tuple>:<kwargs dict>". Results are
        stored as ``str(result)``, so a cache hit always returns a ``str``
        even when the function itself returns another type.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Sort kwargs so f(a=1, b=2) and f(b=2, a=1) share one key.
                key = f"{key_prefix}:{args}:{dict(sorted(kwargs.items()))}"
                cached = self.client.get(key)
                # `is not None`: an empty cached string is still a hit.
                if cached is not None:
                    return cached.decode()
                result = func(*args, **kwargs)
                self.client.setex(key, int(expire.total_seconds()), str(result))
                return result
            return wrapper
        return decorator

    def invalidate(self, key):
        """Delete one cache entry by its full key."""
        self.client.delete(key)

    def flush_pattern(self, pattern):
        """Delete every key matching a glob-style pattern.

        Uses incremental SCAN instead of KEYS so a large keyspace does not
        block the server.
        """
        keys = list(self.client.scan_iter(match=pattern))
        if keys:
            self.client.delete(*keys)
# Usage example
cache = RedisCache()


@cache.cache("user_profile")
def get_user_profile(user_id):
    """Return a user's profile (simulated database query)."""
    import time
    time.sleep(0.5) # simulated latency
    return f"Profile for user {user_id}"


# First call
print(get_user_profile(1)) # cache miss
# Second call (cache hit)
print(get_user_profile(1)) # served from the cache
# Invalidate the entry; the key is "<prefix>:<args tuple>:<kwargs dict>"
cache.invalidate("user_profile:(1,):{}")
5.3 向量数据库Milvus
5.3.1 概念与架构
Milvus核心概念:
- Collection(集合):类似于关系数据库中的表
- Partition(分区):集合的逻辑划分,用于数据管理
- Entity(实体):数据记录,包含向量和标量字段
- Index(索引):加速向量相似度搜索
安装与连接:
# 安装pymilvus
pip install pymilvus
5.3.2 Collection管理
创建Collection:
from pymilvus import MilvusClient, DataType
# Connect to Milvus
client = MilvusClient("http://localhost:19530")
# Build the schema: primary keys are supplied by the caller (auto_id=False),
# and dynamic fields are enabled.
schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=True
)
# Declare the fields
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=384)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=500)
# Create the collection
# NOTE(review): with an explicit schema, metric_type passed here may be
# ignored; the metric is normally set on the index. Confirm against the
# pymilvus documentation.
client.create_collection(
    collection_name="documents",
    schema=schema,
    metric_type="L2" # Euclidean distance
)
5.3.3 Index类型
常用索引类型:
- FLAT:暴力搜索,精度最高,速度最慢
- IVF_FLAT:基于聚类(倒排文件)的索引,向量本身不做压缩,平衡精度和速度
- IVF_SQ8:量化压缩版IVF_FLAT,节省内存
- HNSW:基于图的索引,高维数据性能好
# Create an HNSW index. MilvusClient.create_index takes an IndexParams object
# (built with prepare_index_params), not field_name/index_type keywords.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="embedding",
    index_type="HNSW",
    metric_type="L2",  # Euclidean distance
    params={
        "M": 16,  # max connections per node
        "efConstruction": 200,  # candidate-list size while building
    },
)
client.create_index(collection_name="documents", index_params=index_params)
# A collection has to be loaded before it can be searched.
client.load_collection("documents")
5.3.4 pymilvus API
插入和搜索向量:
import numpy as np
# Insert three rows with random 384-dimensional vectors
data = [
    {"id": 1, "embedding": np.random.rand(384).tolist(), "text": "Document 1"},
    {"id": 2, "embedding": np.random.rand(384).tolist(), "text": "Document 2"},
    {"id": 3, "embedding": np.random.rand(384).tolist(), "text": "Document 3"},
]
client.insert(collection_name="documents", data=data)
# Vector search
query_vector = np.random.rand(384).tolist()
results = client.search(
    collection_name="documents",
    data=[query_vector],
    limit=2,
    # Index-specific search options go under "params"; ef is the HNSW
    # candidate-list size used at search time.
    search_params={"params": {"ef": 50}},
)
print(results)
案例3:Milvus向量搜索完整示例
from pymilvus import MilvusClient, DataType
import numpy as np
class MilvusVectorDB:
    """Thin wrapper around MilvusClient for one collection named "embeddings"."""

    def __init__(self, uri="http://localhost:19530"):
        self.client = MilvusClient(uri)
        self.collection_name = "embeddings"

    def create_collection(self, dim=384):
        """(Re)create the collection with an HNSW/COSINE index and load it.

        An existing collection with the same name is dropped first.
        """
        if self.client.has_collection(self.collection_name):
            self.client.drop_collection(self.collection_name)
        schema = MilvusClient.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )
        schema.add_field("id", DataType.INT64, is_primary=True)
        schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
        schema.add_field("metadata", DataType.JSON)
        # The metric belongs to the index. MilvusClient expects an IndexParams
        # object rather than field_name/index_type keyword arguments.
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="HNSW",
            metric_type="COSINE",  # cosine similarity
            params={"M": 16, "efConstruction": 200},
        )
        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
        )
        # Searching requires a loaded collection.
        self.client.load_collection(self.collection_name)

    def insert_vectors(self, vectors, metadatas=None):
        """Insert vectors, optionally with one metadata dict per vector.

        Rows without metadata get an empty dict because the schema declares
        the "metadata" field. Raises ValueError when ``metadatas`` is given
        but its length differs from ``vectors``.
        """
        if metadatas and len(metadatas) != len(vectors):
            raise ValueError(
                f"got {len(vectors)} vectors but {len(metadatas)} metadatas"
            )
        data = [
            {
                "embedding": np.asarray(vec).tolist(),
                "metadata": metadatas[i] if metadatas else {},
            }
            for i, vec in enumerate(vectors)
        ]
        return self.client.insert(collection_name=self.collection_name, data=data)

    def search_vectors(self, query_vector, top_k=10, filter=None):
        """Return the ``top_k`` nearest hits for one query vector.

        ``filter`` is an optional Milvus boolean expression string.
        """
        # HNSW needs ef >= top_k; keep the original 50 as the floor.
        search_params = {"params": {"ef": max(50, top_k)}}
        results = self.client.search(
            collection_name=self.collection_name,
            data=[np.asarray(query_vector).tolist()],
            limit=top_k,
            search_params=search_params,
            filter=filter or "",
        )
        # One query vector was sent, so take its result list.
        return results[0]

    def get_entity_by_id(self, entity_id):
        """Fetch the entity with the given primary key."""
        return self.client.get(
            collection_name=self.collection_name,
            ids=[entity_id],
        )
# Usage example
db = MilvusVectorDB()
db.create_collection(dim=384)
# Insert 100 random 384-dimensional vectors, each with a small metadata dict
vectors = np.random.rand(100, 384)
metadatas = [{"text": f"Document {i}"} for i in range(100)]
db.insert_vectors(vectors, metadatas)
# Search for the 5 vectors most similar to a random query
query_vec = np.random.rand(384)
results = db.search_vectors(query_vec, top_k=5)
for res in results:
    print(f"ID: {res['id']}, Distance: {res['distance']}")
5.4 Elasticsearch
5.4.1 倒排索引
倒排索引原理:
文档1: "Python is a programming language"
文档2: "Elasticsearch is a search engine"
文档3: "Python Elasticsearch integration"
倒排索引(已转为小写,并去掉了 "is"、"a" 等停用词):
"python" -> [文档1, 文档3]
"programming" -> [文档1]
"language" -> [文档1]
"elasticsearch" -> [文档2, 文档3]
"search" -> [文档2]
"engine" -> [文档2]
"integration" -> [文档3]
创建索引:
from elasticsearch import Elasticsearch
# Connect to Elasticsearch
es = Elasticsearch("http://localhost:9200")
# Index definition: two analyzed text fields, an exact-match keyword field,
# and a date field.
mapping = {
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "content": {"type": "text"},
            "author": {"type": "keyword"},
            "published_date": {"type": "date"}
        }
    }
}
# NOTE(review): newer elasticsearch-py releases deprecate the `body` argument
# in favour of named parameters — confirm for the client version in use.
es.indices.create(index="articles", body=mapping)
5.4.2 DSL查询
各种查询类型:
# Full-text search on the content field
query = {
    "query": {
        "match": {
            "content": "data storage"
        }
    }
}
result = es.search(index="articles", body=query)
# Phrase search
query = {
    "query": {
        "match_phrase": {
            "content": "machine learning"
        }
    }
}
# Search across several fields
query = {
    "query": {
        "multi_match": {
            "query": "python",
            "fields": ["title", "content"]
        }
    }
}
# Boolean query combining must, filter and must_not clauses
query = {
    "query": {
        "bool": {
            "must": [{"match": {"content": "data"}}],
            "filter": [{"term": {"author": "John"}}],
            "must_not": [{"match": {"content": "deprecated"}}]
        }
    }
}
5.4.3 中文分词
使用中文分词器:
# Create an index whose text fields use a custom Chinese analyzer.
# NOTE(review): the "ik_max_word" tokenizer presumably comes from the IK
# analysis plugin, which must be installed on the cluster — confirm.
mapping = {
    "settings": {
        "analysis": {
            "analyzer": {
                "chinese_analyzer": {
                    "tokenizer": "ik_max_word",
                    "filter": ["lowercase"]
                }
            }
        }
    },
    "mappings": {
        "properties": {
            "title": {"type": "text", "analyzer": "chinese_analyzer"},
            "content": {"type": "text", "analyzer": "chinese_analyzer"}
        }
    }
}
es.indices.create(index="chinese_docs", body=mapping)
5.4.4 聚合操作
聚合查询示例:
# Count documents per author. "author" is mapped as a keyword field in the
# index created earlier, so aggregate on it directly; an "author.keyword"
# sub-field only exists when the field is mapped as text with a keyword
# multi-field.
query = {
    "size": 0,  # aggregation only: skip returning hits
    "aggs": {
        "authors": {
            "terms": {
                "field": "author",
                "size": 10
            }
        }
    }
}
# Documents per calendar month
query = {
    "size": 0,
    "aggs": {
        "publication_over_time": {
            "date_histogram": {
                "field": "published_date",
                "calendar_interval": "month"
            }
        }
    }
}
# Average of a numeric field.
# NOTE(review): "score" is not part of the mapping shown earlier; this
# assumes the index has such a numeric field.
query = {
    "size": 0,
    "aggs": {
        "avg_score": {
            "avg": {
                "field": "score"
            }
        }
    }
}
案例4:Elasticsearch全文搜索
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
class ESClient:
    """Small convenience wrapper around the Elasticsearch client."""

    def __init__(self, host="localhost", port=9200):
        self.client = Elasticsearch(f"http://{host}:{port}")

    def create_index(self, index_name, mapping=None):
        """Create the index unless it already exists.

        ``mapping`` is an optional body with "settings" and/or "mappings".
        """
        if self.client.indices.exists(index=index_name):
            return
        if mapping:
            self.client.indices.create(index=index_name, body=mapping)
        else:
            self.client.indices.create(index=index_name)

    def index_documents(self, index_name, documents, refresh=False):
        """Bulk-index documents and return the result of ``helpers.bulk``.

        A document's "id" value becomes its ``_id``; documents without one
        let Elasticsearch generate the id. With ``refresh=True`` the index is
        refreshed afterwards so the documents are searchable immediately.
        """
        actions = []
        for doc in documents:
            action = {"_index": index_name, "_source": doc}
            # Only send _id when the document actually has one.
            if doc.get("id") is not None:
                action["_id"] = doc["id"]
            actions.append(action)
        result = bulk(self.client, actions)
        if refresh:
            self.client.indices.refresh(index=index_name)
        return result

    def search(self, index_name, query, size=10):
        """Run a query DSL body against an index and return the raw response."""
        return self.client.search(index=index_name, body=query, size=size)

    def highlight_search(self, index_name, query_text, fields=("title", "content")):
        """Multi-field match query that also returns highlighted fragments."""
        # Immutable default plus a local copy: no list shared between calls.
        field_list = list(fields)
        query = {
            "query": {
                "multi_match": {
                    "query": query_text,
                    "fields": field_list,
                }
            },
            "highlight": {
                "fields": {field: {} for field in field_list}
            },
        }
        return self.client.search(index=index_name, body=query)
# Usage example
es = ESClient()
# Create the index
mapping = {
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "content": {"type": "text"},
            "author": {"type": "keyword"}
        }
    }
}
es.create_index("documents", mapping)
# Index the documents
documents = [
    {"id": "1", "title": "Python基础", "content": "Python是一种高级编程语言", "author": "John"},
    {"id": "2", "title": "Elasticsearch指南", "content": "ES是一个分布式搜索引擎", "author": "Alice"},
    {"id": "3", "title": "机器学习入门", "content": "机器学习是人工智能的一个分支", "author": "Bob"}
]
es.index_documents("documents", documents)
# Newly indexed documents only become searchable after a refresh; without
# this, the search right below can come back empty.
es.client.indices.refresh(index="documents")
# Search
results = es.highlight_search("documents", "编程语言")
for hit in results["hits"]["hits"]:
    print(f"Title: {hit['_source']['title']}")
    print(f"Highlight: {hit.get('highlight', {})}")
5.5 Embedding模型
5.5.1 文本嵌入原理
文本嵌入流程:
文本输入 → Tokenizer分词 → 嵌入层 → 向量输出
例如:
输入: "Hello world"
输出: [0.123, -0.456, 0.789, ...] (维度通常为768或384)
常用Embedding模型:
- Sentence-BERT:专注于句子级嵌入
- OpenAI Embeddings:text-embedding-ada-002
- BGE:中文开源模型
- MiniLM:轻量级模型
5.5.2 维度选择
维度选择考虑因素:
| 因素 | 说明 |
|---|---|
| 模型能力 | 更高维度通常表示更强的表达能力 |
| 存储成本 | 高维度向量占用更多存储空间 |
| 搜索速度 | 维度越高,相似度计算越慢 |
| 应用场景 | 语义搜索通常用384-768维 |
5.5.3 批量编码
使用Sentence-BERT批量编码:
from sentence_transformers import SentenceTransformer
import numpy as np
# Load the model
# NOTE(review): all-MiniLM-L6-v2 is presumably trained mainly on English text,
# while the sentences below are Chinese; a multilingual model may give better
# embeddings — confirm.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Sentences to encode as one batch
sentences = [
    "Python是一种高级编程语言",
    "机器学习是人工智能的分支",
    "深度学习是机器学习的子领域",
    "神经网络由多个神经元组成"
]
# Generate one embedding vector per sentence
embeddings = model.encode(sentences)
print(f"向量形状: {embeddings.shape}") # (4, 384)
print(f"第一个向量: {embeddings[0][:5]}")
案例5:Embedding服务封装
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List
class EmbeddingService:
    """Wraps a SentenceTransformer model with encoding and similarity helpers."""

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name)
        self.dim = self.model.get_sentence_embedding_dimension()

    def encode(self, texts: List[str], batch_size=32) -> np.ndarray:
        """Encode texts into an array of shape (len(texts), dim).

        An empty list returns an empty (0, dim) array instead of failing.
        """
        if not texts:
            return np.empty((0, self.dim), dtype=np.float32)
        # The model batches internally; no manual slicing needed.
        return np.asarray(self.model.encode(texts, batch_size=batch_size))

    def encode_single(self, text: str) -> np.ndarray:
        """Encode one text into a 1-D vector."""
        return self.model.encode(text)

    def compute_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denom == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / denom)

    def rank_by_similarity(self, query_vec: np.ndarray, doc_vecs: np.ndarray) -> List[int]:
        """Return document indices ordered from most to least similar."""
        docs = np.asarray(doc_vecs, dtype=float)
        if docs.size == 0:
            return []
        query = np.asarray(query_vec, dtype=float)
        # Cosine similarity for all documents at once; zero-norm rows score 0.
        denom = np.linalg.norm(docs, axis=1) * np.linalg.norm(query)
        sims = np.divide(docs @ query, denom, out=np.zeros(len(docs)), where=denom != 0)
        return np.argsort(sims)[::-1].tolist()
# Usage example
service = EmbeddingService()
# Encode the documents
documents = [
    "Python编程入门指南",
    "机器学习算法详解",
    "深度学习实战",
    "数据结构与算法"
]
doc_vectors = service.encode(documents)
# Encode the query
query = "Python机器学习"
query_vec = service.encode_single(query)
# Similarity between the query and each document, printed next to its text
similarities = [service.compute_similarity(query_vec, dv) for dv in doc_vectors]
for doc, sim in zip(documents, similarities):
    print(f"{doc}: {sim:.4f}")
5.6 相似度算法
5.6.1 余弦相似度
余弦相似度公式:
sim(A, B) = (A · B) / (||A|| × ||B||)
实现:
import numpy as np
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine similarity of two vectors.

    Returns 0.0 when either vector has zero norm, instead of dividing by
    zero and yielding NaN.
    """
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denom)


# Example
vec1 = np.array([1, 2, 3])
vec2 = np.array([4, 5, 6])
print(cosine_similarity(vec1, vec2)) # 0.9746
5.6.2 欧氏距离
欧氏距离公式:
distance(A, B) = sqrt(sum((A_i - B_i)^2))
实现:
def euclidean_distance(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two vectors."""
    delta = vec1 - vec2
    return np.linalg.norm(delta)


# Example
vec1 = np.array([1, 2, 3])
vec2 = np.array([4, 5, 6])
print(euclidean_distance(vec1, vec2)) # 5.196
5.6.3 内积
内积相似度:
def dot_product(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the inner product of two vectors."""
    inner = np.dot(vec1, vec2)
    return inner


# Example
vec1 = np.array([1, 2, 3])
vec2 = np.array([4, 5, 6])
print(dot_product(vec1, vec2)) # 32
5.6.4 HNSW索引原理
HNSW(Hierarchical Navigable Small Worlds):
层级结构示意:
Level 2: A ---- B
\ /
\ /
Level 1: C -- D -- E
| | |
| | |
Level 0: F -- G -- H -- I
搜索过程:
1. 从顶层入口点开始
2. 在每层找到最近的邻居
3. 向下层移动,重复搜索
4. 在底层(第0层)用更大的候选列表(ef)做更细致的近邻搜索,得到近似最近邻结果
案例6:多种相似度算法对比
import numpy as np
class SimilarityCalculator:
    """Collection of similarity measures between two numeric vectors."""

    @staticmethod
    def cosine(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Cosine similarity; 0.0 when either vector has zero norm."""
        norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / norm)

    @staticmethod
    def euclidean(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Euclidean distance mapped to a similarity in (0, 1]."""
        dist = np.linalg.norm(vec1 - vec2)
        return float(1 / (1 + dist))

    @staticmethod
    def dot(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Inner product normalised by both norms.

        This is the same formula as ``cosine``, so it delegates there and
        shares the zero-norm guard instead of returning NaN.
        """
        return SimilarityCalculator.cosine(vec1, vec2)

    @staticmethod
    def manhattan(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Manhattan (L1) distance mapped to a similarity in (0, 1]."""
        dist = np.sum(np.abs(vec1 - vec2))
        return float(1 / (1 + dist))

    @staticmethod
    def jaccard(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Jaccard similarity for binary vectors; 0.0 when both are all zero."""
        intersection = np.sum(np.logical_and(vec1, vec2))
        union = np.sum(np.logical_or(vec1, vec2))
        if union == 0:
            return 0.0
        return float(intersection / union)


# Usage example
calculator = SimilarityCalculator()
vec1 = np.array([1, 2, 3, 4, 5])
vec2 = np.array([2, 3, 4, 5, 6])
vec3 = np.array([10, 20, 30, 40, 50])
print(f"余弦相似度 (v1, v2): {calculator.cosine(vec1, vec2):.4f}")
print(f"欧氏相似度 (v1, v2): {calculator.euclidean(vec1, vec2):.4f}")
print(f"余弦相似度 (v1, v3): {calculator.cosine(vec1, vec3):.4f}")
5.7 混合检索
5.7.1 向量+关键词检索
混合检索架构:
查询
│
┌─────────────┼─────────────┐
▼ ▼ ▼
向量检索 关键词检索 语义检索
│ │ │
▼ ▼ ▼
Milvus Elasticsearch Embedding
│ │ │
└─────────────┼─────────────┘
▼
结果融合
│
▼
最终结果
5.7.2 RRF(Reciprocal Rank Fusion)
RRF公式:
score(d) = sum(1 / (k + rank_i(d)))
其中:
- k: 常数(通常取60)
- rank_i(d): 结果d在第i个检索器中的排名
实现:
def reciprocal_rank_fusion(rankings: list, k: int = 60) -> dict:
    """Fuse several rankings with Reciprocal Rank Fusion (RRF).

    Args:
        rankings: list of rankings; each is a list of document ids, best first.
        k: smoothing constant from the RRF formula (commonly 60).

    Returns:
        Dict mapping document id to fused score, ordered by descending score.
        A document contributes 1 / (k + rank) per ranking, with ranks
        starting at 1. If an id is repeated inside one ranking, only its
        first (best) position counts.
    """
    scores = {}
    for rank_list in rankings:
        seen = set()
        for rank, doc_id in enumerate(rank_list, start=1):
            if doc_id in seen:
                # A document has a single rank per retriever.
                continue
            seen.add(doc_id)
            scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank)
    # Highest fused score first
    return dict(sorted(scores.items(), key=lambda item: -item[1]))


# Example
vector_ranking = ["doc1", "doc3", "doc2", "doc4"]
keyword_ranking = ["doc2", "doc1", "doc5", "doc3"]
fused = reciprocal_rank_fusion([vector_ranking, keyword_ranking])
print(fused)
# Scores rounded to 4 decimals:
# {'doc1': 0.0325, 'doc2': 0.0323, 'doc3': 0.0318, 'doc5': 0.0159, 'doc4': 0.0156}
案例7:混合检索实现
from typing import List, Dict, Tuple
import numpy as np
class HybridRetriever:
    """Combine vector search and keyword search with Reciprocal Rank Fusion.

    Args:
        vector_db: object with ``search_vectors(query_vec, top_k=...)`` that
            returns hits carrying an "id".
        keyword_db: keyword search backend (see ``index_name``).
        embedding_service: object with ``encode_single(text)``.
        index_name: when given, ``keyword_db`` is called the way
            ``ESClient.search(index_name, query_dsl, size=...)`` expects.
            When None (default), the legacy call
            ``keyword_db.search(query, size=...)`` is kept.
        text_fields: fields searched by the generated multi_match query.

    Both backends must use the same document ids, otherwise the fusion
    cannot recognise a document found by both.
    """

    def __init__(self, vector_db, keyword_db, embedding_service,
                 index_name: "str | None" = None,
                 text_fields=("title", "content")):
        self.vector_db = vector_db
        self.keyword_db = keyword_db
        self.embedding_service = embedding_service
        self.index_name = index_name
        self.text_fields = list(text_fields)

    def vector_search(self, query: str, top_k: int = 10) -> List[str]:
        """Return ids of the nearest vectors, best first."""
        query_vec = self.embedding_service.encode_single(query)
        results = self.vector_db.search_vectors(query_vec, top_k=top_k)
        return [str(res["id"]) for res in results]

    def keyword_search(self, query: str, top_k: int = 10) -> List[str]:
        """Return ids of the best keyword matches, best first."""
        if self.index_name is None:
            # Legacy backend interface: search(text, size=...).
            results = self.keyword_db.search(query, size=top_k)
        else:
            # ESClient-style interface: search(index_name, query_dsl, size=...).
            dsl = {"query": {"multi_match": {"query": query, "fields": self.text_fields}}}
            results = self.keyword_db.search(self.index_name, dsl, size=top_k)
        return [hit["_id"] for hit in results["hits"]["hits"]]

    def reciprocal_rank_fusion(self, rankings: List[List[str]], k: int = 60) -> Dict[str, float]:
        """Fuse rankings with RRF; returns id -> score, highest score first."""
        scores = {}
        for rank_list in rankings:
            for idx, doc_id in enumerate(rank_list):
                # Ranks start at 1, hence idx + 1.
                scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + idx + 1)
        return dict(sorted(scores.items(), key=lambda x: -x[1]))

    def hybrid_search(self, query: str, top_k: int = 10, rrf_k: int = 60) -> List[Tuple[str, float]]:
        """Run both retrievers and return the top_k fused (id, score) pairs."""
        vector_results = self.vector_search(query, top_k)
        keyword_results = self.keyword_search(query, top_k)
        fused = self.reciprocal_rank_fusion([vector_results, keyword_results], k=rrf_k)
        return list(fused.items())[:top_k]


# Usage example (pseudo-code)
# vector_db = MilvusVectorDB()
# keyword_db = ESClient()
# embedding_service = EmbeddingService()
# retriever = HybridRetriever(vector_db, keyword_db, embedding_service, index_name="documents")
# results = retriever.hybrid_search("Python机器学习", top_k=5)
# for doc_id, score in results:
#     print(f"Document: {doc_id}, Score: {score:.4f}")
5.8 数据管道设计
5.8.1 ETL流程
ETL架构:
数据源 → 抽取(Extract) → 转换(Transform) → 加载(Load) → 数据仓库
详细流程:
1. Extract: 从数据库、API、文件等抽取原始数据
2. Transform: 清洗、转换、标准化数据
3. Load: 加载到目标存储系统
ETL示例:
import pandas as pd
import psycopg2
from sqlalchemy import create_engine
class ETLPipeline:
    """Extract-transform-load helper built on two SQLAlchemy engines."""

    def __init__(self, source_conn_str, target_conn_str):
        self.source_engine = create_engine(source_conn_str)
        self.target_engine = create_engine(target_conn_str)

    def extract(self, query: str) -> pd.DataFrame:
        """Run ``query`` against the source database and return a DataFrame."""
        frame = pd.read_sql(query, self.source_engine)
        return frame

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop rows with nulls, format created_at as a date, stamp the run time."""
        # Rows containing any null are discarded.
        cleaned = df.dropna()
        # created_at becomes a 'YYYY-MM-DD' string.
        day_strings = pd.to_datetime(cleaned['created_at']).dt.strftime('%Y-%m-%d')
        # processed_at records when this transformation ran.
        return cleaned.assign(created_at=day_strings, processed_at=pd.Timestamp.now())

    def load(self, df: pd.DataFrame, table_name: str):
        """Append the DataFrame to ``table_name`` in the target database."""
        df.to_sql(table_name, self.target_engine, if_exists='append', index=False)

    def run(self, query: str, table_name: str):
        """Run extract, transform and load, then print the loaded row count."""
        frame = self.transform(self.extract(query))
        self.load(frame, table_name)
        print(f"Loaded {len(frame)} records")
# Usage example: copy the users table from the source database into the
# users_staging table of the target database.
pipeline = ETLPipeline(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run("SELECT * FROM users", "users_staging")
5.8.2 增量更新
增量更新策略:
class IncrementalETL(ETLPipeline):
    """ETL pipeline that only processes rows changed since the last load.

    Table names are interpolated into SQL (identifiers cannot be bound as
    parameters), so they must come from trusted configuration, never from
    user input.
    """

    def __init__(self, source_conn_str, target_conn_str):
        super().__init__(source_conn_str, target_conn_str)
        self.last_sync_time = None

    def get_last_sync_time(self, table_name: str) -> str:
        """Return the newest processed_at in the target table.

        Falls back to "1970-01-01" when the table does not exist yet (first
        run) or holds no rows. The database may hand back a datetime rather
        than a string.
        """
        from sqlalchemy import inspect, text

        if not inspect(self.target_engine).has_table(table_name):
            return "1970-01-01"
        # Plain SQL strings must be wrapped in text() for Connection.execute.
        query = text(f"SELECT MAX(processed_at) FROM {table_name}")
        with self.target_engine.connect() as conn:
            result = conn.execute(query).fetchone()
        return result[0] if result and result[0] else "1970-01-01"

    def incremental_extract(self, table_name: str, target_table: str = None) -> pd.DataFrame:
        """Extract source rows whose updated_at is after the last sync time.

        Args:
            table_name: source table to read from.
            target_table: target table holding the processed_at watermark;
                defaults to ``table_name`` for backward compatibility.
        """
        from sqlalchemy import text

        self.last_sync_time = self.get_last_sync_time(target_table or table_name)
        # The watermark is bound as a parameter, not formatted into the SQL.
        query = text(f"SELECT * FROM {table_name} WHERE updated_at > :last_sync")
        return pd.read_sql(
            query, self.source_engine, params={"last_sync": self.last_sync_time}
        )

    def run_incremental(self, source_table: str, target_table: str):
        """Run one incremental extract-transform-load cycle."""
        # The watermark lives in the target table, not the source table.
        df = self.incremental_extract(source_table, target_table)
        if not df.empty:
            df = self.transform(df)
            self.load(df, target_table)
            print(f"Incrementally loaded {len(df)} records")
        else:
            print("No new records to process")


# Usage example
pipeline = IncrementalETL(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run_incremental("users", "users_staging")
5.8.3 一致性保障
数据一致性策略:
import contextlib
class TransactionalETL(ETLPipeline):
    """ETL pipeline whose delete-and-reload step runs in one transaction.

    The table name is interpolated into SQL and must come from trusted
    configuration.
    """

    def __init__(self, source_conn_str, target_conn_str):
        super().__init__(source_conn_str, target_conn_str)

    @contextlib.contextmanager
    def transaction(self):
        """Yield a target connection inside a transaction.

        Commits when the block finishes, rolls back and re-raises on any
        exception, and always closes the connection.
        """
        conn = self.target_engine.connect()
        trans = conn.begin()
        try:
            yield conn
            trans.commit()
            print("Transaction committed")
        except Exception as e:
            trans.rollback()
            print(f"Transaction rolled back: {e}")
            raise
        finally:
            conn.close()

    def run_with_transaction(self, query: str, table_name: str):
        """Extract, transform, then atomically replace the matching target rows."""
        from sqlalchemy import bindparam, inspect, text

        df = self.extract(query)
        df = self.transform(df)
        if df.empty:
            # Nothing to load; an empty "IN ()" list would be invalid SQL.
            print("No records to load")
            return
        ids = df['id'].tolist()
        # An expanding bind parameter renders one placeholder per id.
        delete_stmt = text(f"DELETE FROM {table_name} WHERE id IN :ids").bindparams(
            bindparam("ids", expanding=True)
        )
        with self.transaction() as conn:
            # Both the delete and the insert use the transaction's own
            # connection, so they commit or roll back together.
            if inspect(conn).has_table(table_name):
                conn.execute(delete_stmt, {"ids": ids})
            df.to_sql(table_name, conn, if_exists='append', index=False)
        # Reached only after the transaction has committed.
        print(f"Successfully loaded {len(df)} records")


# Usage example
pipeline = TransactionalETL(
    source_conn_str="postgresql://user:pass@localhost/source_db",
    target_conn_str="postgresql://user:pass@localhost/target_db"
)
pipeline.run_with_transaction("SELECT * FROM orders WHERE status = 'completed'", "completed_orders")
案例8:完整数据管道
import pandas as pd
from sqlalchemy import create_engine
from datetime import datetime, timedelta
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataPipeline:
    """Extract, validate, transform and load pipeline with logging.

    ``config`` needs the keys "source" and "target", each a SQLAlchemy
    connection string. Table names are interpolated into SQL and must come
    from trusted configuration.
    """

    def __init__(self, config):
        self.config = config
        self.source_engine = create_engine(config["source"])
        self.target_engine = create_engine(config["target"])

    def extract(self, table_name: str, days: int = 7) -> pd.DataFrame:
        """Extract rows whose created_at falls within the last ``days`` days."""
        from sqlalchemy import text

        cutoff_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')
        # The cutoff is bound as a parameter rather than formatted into the SQL.
        query = text(f"SELECT * FROM {table_name} WHERE created_at >= :cutoff")
        logger.info(f"Extracting data from {table_name} since {cutoff_date}")
        return pd.read_sql(query, self.source_engine, params={"cutoff": cutoff_date})

    def validate(self, df: pd.DataFrame) -> bool:
        """Return True when required columns exist and ids are unique."""
        required_columns = ['id', 'name', 'created_at']
        if not all(col in df.columns for col in required_columns):
            logger.error("Missing required columns")
            return False
        if df['id'].duplicated().any():
            logger.error("Duplicate IDs found")
            return False
        logger.info("Data validation passed")
        return True

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy with derived columns; the input is not modified."""
        logger.info("Transforming data...")
        # Cleaning: drop_duplicates returns a new frame.
        df = df.drop_duplicates(subset=['id'])
        # Type conversion
        df['created_at'] = pd.to_datetime(df['created_at'])
        # Derived fields
        df['year_month'] = df['created_at'].dt.strftime('%Y-%m')
        df['processing_time'] = datetime.now()
        return df

    def load(self, df: pd.DataFrame, table_name: str, mode: str = 'append'):
        """Write the DataFrame to the target; ``mode`` is pandas' if_exists."""
        logger.info(f"Loading {len(df)} records to {table_name}")
        df.to_sql(table_name, self.target_engine, if_exists=mode, index=False)

    def monitor(self, df: pd.DataFrame, step: str):
        """Log basic data-quality statistics for a pipeline step."""
        stats = {
            "step": step,
            "record_count": len(df),
            "columns": list(df.columns),
            "null_counts": df.isnull().sum().to_dict(),
            "timestamp": datetime.now().isoformat()
        }
        logger.info(f"Monitoring stats: {stats}")

    def run(self, table_name: str, days: int = 7, mode: str = 'append'):
        """Run the whole pipeline for one table.

        Args:
            table_name: table read from the source and written to the target.
            days: look-back window forwarded to ``extract``.
            mode: if_exists mode forwarded to ``load``.

        Stops before loading when validation fails. Any exception is logged
        with its traceback and re-raised.
        """
        try:
            # Extract
            df = self.extract(table_name, days=days)
            self.monitor(df, "extract")
            # Validate
            if not self.validate(df):
                return
            # Transform
            df = self.transform(df)
            self.monitor(df, "transform")
            # Load
            self.load(df, table_name, mode=mode)
            logger.info("Pipeline completed successfully")
        except Exception as e:
            logger.error(f"Pipeline failed: {e}", exc_info=True)
            raise
# Configuration: connection strings for the source and target databases
config = {
    "source": "postgresql://user:pass@localhost/source_db",
    "target": "postgresql://user:pass@localhost/data_warehouse"
}
# Run the pipeline for the user_activity table
pipeline = DataPipeline(config)
pipeline.run("user_activity")
小结
本章介绍了数据存储与检索的核心内容,包括:
- PostgreSQL:数据模型设计、索引优化、JSON类型、全文搜索
- Redis:数据结构、发布订阅、Lua脚本、缓存策略
- Milvus向量数据库:Collection管理、Index类型、pymilvus API
- Elasticsearch:倒排索引、DSL查询、中文分词、聚合操作
- Embedding模型:文本嵌入原理、维度选择、批量编码
- 相似度算法:余弦相似度、欧氏距离、内积、HNSW索引原理
- 混合检索:向量+关键词检索、RRF得分融合策略
- 数据管道设计:ETL流程、增量更新、一致性保障
掌握这些技术对于构建现代化的AI应用,特别是RAG系统,至关重要。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)