手写 AI 缓存系统:从零实现语义缓存与结果复用
一、为什么 AI 系统需要缓存
在实际的 AI 应用中,重复计算是一个普遍存在的性能瓶颈。以 RAG 系统为例,当多个用户询问相似问题(如"公司的薪酬政策是什么"和"请说明薪酬规定"),如果每次都重新执行嵌入向量检索、大模型推理等全套流程,不仅响应延迟高,还会产生大量 API 调用开销。
缓存的核心价值体现在三个维度:
降低延迟:缓存命中时直接返回结果,避免大模型推理的 3-10 秒等待。对于实时对话系统,这直接决定了用户体验的流畅度。
节省成本:LLM API 按 token 计费。一个 2000 token 的查询响应,在 GPT-4 上约需 0.06 元。即使每天只有 1000 次重复查询,一年就能省下 2 万元以上。
系统保护:缓存吸收突发流量峰值,防止后端模型服务过载。这在并发高峰期尤为重要。
本文将从零构建一个完整的 AI 缓存系统,覆盖精确缓存、语义缓存、LRU 淘汰、TTL 过期等核心机制。全部代码手写实现,不依赖任何缓存框架。
二、缓存系统架构总览
一个完整的 AI 缓存系统包含以下层次:
用户请求
↓
[ 缓存层 ]
├── 一级缓存:精确匹配(内存哈希表)
├── 二级缓存:语义匹配(向量相似度检索)
└── 三级缓存:持久化存储(可选 Redis/磁盘)
↓
[ 缓存策略引擎 ]
├── LRU 淘汰策略
├── TTL 过期策略
└── 预热与写穿透策略
↓
[ 实际计算 ]
└── LLM 推理 / RAG 检索
第一层精确缓存用于完全相同的查询,毫秒级响应。第二层语义缓存用于含义相似但表述不同的查询,需要向量化 + 相似度检索。第三层持久化缓存用于系统重启后恢复热点数据。
三、基础数据结构:双向链表 + 哈希表
LRU 缓存的核心数据结构是「双向链表 + 哈希表」组合。双向链表维护访问顺序,哈希表提供 O(1) 查找。
from __future__ import annotations
import hashlib
import json
import time
import threading
from typing import Any, Optional, Callable
from dataclasses import dataclass, field
import numpy as np
@dataclass
class CacheNode:
"""缓存节点,双向链表的基本单元"""
key: str
value: Any
prev: Optional[CacheNode] = None
next: Optional[CacheNode] = None
size: int = 0 # 缓存项大小(字节)
hits: int = 0 # 访问次数
created_at: float = 0.0
accessed_at: float = 0.0
class LRUCache:
"""
基于双向链表 + 哈希表的 LRU 缓存
特点:O(1) 查找、插入、删除
"""
def __init__(self, capacity: int = 1000, max_memory_mb: int = 256):
self.capacity = capacity
self.max_memory = max_memory_mb * 1024 * 1024 # 转为字节
self.cache: dict[str, CacheNode] = {}
self.head = CacheNode(key="__head__", value=None) # 伪头节点
self.tail = CacheNode(key="__tail__", value=None) # 伪尾节点
self.head.next = self.tail
self.tail.prev = self.head
self.current_memory = 0
self.lock = threading.RLock() # 线程安全
self._init_metrics()
def _init_metrics(self):
self.metrics = {
"hits": 0,
"misses": 0,
"evictions": 0,
"total_lookups": 0,
}
def _remove_node(self, node: CacheNode):
"""从链表中移除节点(不操作哈希表)"""
prev_node = node.prev
next_node = node.next
prev_node.next = next_node
next_node.prev = prev_node
def _add_to_head(self, node: CacheNode):
"""将节点插入到头部(最近使用)"""
node.prev = self.head
node.next = self.head.next
self.head.next.prev = node
self.head.next = node
def _move_to_head(self, node: CacheNode):
"""将已存在节点移到头部"""
self._remove_node(node)
self._add_to_head(node)
def _evict_tail(self) -> Optional[CacheNode]:
"""移除最久未使用的节点(尾部)"""
if self.tail.prev == self.head:
return None
node = self.tail.prev
self._remove_node(node)
del self.cache[node.key]
self.current_memory -= node.size
self.metrics["evictions"] += 1
return node
def get(self, key: str) -> Optional[Any]:
"""获取缓存项,同时更新访问顺序"""
with self.lock:
self.metrics["total_lookups"] += 1
node = self.cache.get(key)
if node is None:
self.metrics["misses"] += 1
return None
self._move_to_head(node)
node.hits += 1
node.accessed_at = time.time()
self.metrics["hits"] += 1
return node.value
def put(self, key: str, value: Any, cost: int = 0):
"""插入或更新缓存项"""
with self.lock:
now = time.time()
node = self.cache.get(key)
if node is not None:
# 更新已有节点
self.current_memory -= node.size
node.value = value
node.size = cost
node.accessed_at = now
node.created_at = now if node.created_at == 0 else node.created_at
self._move_to_head(node)
self.current_memory += cost
return
# 新建节点
node = CacheNode(
key=key,
value=value,
size=cost,
hits=0,
created_at=now,
accessed_at=now,
)
# 检查容量和内存限制
while (
len(self.cache) >= self.capacity
or self.current_memory + cost > self.max_memory
):
if not self._evict_tail():
break
self.cache[key] = node
self._add_to_head(node)
self.current_memory += cost
def invalidate(self, key: str):
"""主动失效指定缓存项"""
with self.lock:
node = self.cache.pop(key, None)
if node:
self.current_memory -= node.size
self._remove_node(node)
def clear(self):
"""清空所有缓存"""
with self.lock:
self.cache.clear()
self.head.next = self.tail
self.tail.prev = self.head
self.current_memory = 0
self._init_metrics()
def hit_rate(self) -> float:
"""缓存命中率"""
total = self.metrics["hits"] + self.metrics["misses"]
return self.metrics["hits"] / total if total > 0 else 0.0
def __len__(self):
return len(self.cache)
def __contains__(self, key: str):
with self.lock:
return key in self.cache
这段代码的核心设计思想是使用哨兵节点(dummy head/tail)消除边界条件判断。在插入和删除操作中,头尾虚拟节点让代码不需要单独处理空链表情况,显著减少 bug 来源。
threading.RLock(可重入锁)保证线程安全,因为在 put 方法内可能触发 _evict_tail,如果使用普通锁 Lock,同一线程内再次获取同一把锁会造成死锁。
四、精确缓存:查询键的标准化
精确缓存要求完全相同的关键字才能命中。但用户查询可能有细微差异(多余空格、大小写、标点符号),我们需要先对查询进行键标准化:
class ExactCache:
"""
精确缓存层:完全匹配查询
支持查询键标准化和 TTL 过期
"""
def __init__(self, capacity: int = 10000, default_ttl: int = 3600):
self.lru = LRUCache(capacity=capacity)
self.default_ttl = default_ttl
self.ttl_map: dict[str, float] = {} # key → 过期时间戳
@staticmethod
def normalize_query(query: str) -> str:
"""标准化查询字符串,提高命中率"""
# 折叠空白
normalized = " ".join(query.split())
# 转小写
normalized = normalized.lower().strip()
# 统一标点
normalized = normalized.replace("?", "?").replace("!", "!")
# 移除末尾句号
normalized = normalized.rstrip("。.")
return normalized
def make_key(self, query: str, context: Optional[dict] = None) -> str:
"""生成缓存键:查询 + 上下文哈希"""
normalized = self.normalize_query(query)
if context:
context_str = json.dumps(context, sort_keys=True)
context_hash = hashlib.md5(context_str.encode()).hexdigest()[:8]
return f"{normalized}::ctx:{context_hash}"
return normalized
def get(self, query: str, context: Optional[dict] = None) -> Optional[Any]:
key = self.make_key(query, context)
# 检查 TTL
if key in self.ttl_map and time.time() > self.ttl_map[key]:
self.lru.invalidate(key)
del self.ttl_map[key]
return None
return self.lru.get(key)
def set(
self,
query: str,
value: Any,
context: Optional[dict] = None,
ttl: Optional[int] = None,
):
key = self.make_key(query, context)
cost = len(json.dumps(value, default=str).encode("utf-8"))
self.lru.put(key, value, cost=cost)
self.ttl_map[key] = time.time() + (ttl or self.default_ttl)
def hit_rate(self) -> float:
return self.lru.hit_rate()
键标准化中," ".join(query.split()) 是非常实用的技巧:它同时完成了去除首尾空格和折叠中间连续空格两件事,比 strip() + 正则替换更简洁高效。
make_key 中的 context 参数支持按场景隔离缓存。例如,同一个问题"今天天气如何"在不同城市上下文中应该有不同的缓存。
五、语义缓存:向量相似度匹配
语义缓存是 AI 缓存的精髓所在。当用户问"如何重置密码"和"忘了密码怎么办"时,虽然字符串不同,但语义高度相似,应该复用同一份缓存结果。
实现语义缓存包含三个步骤:
- 查询向量化:将文本转为高维向量
- 相似度检索:在已有缓存向量库中查找最近邻
- 相似度阈值判断:高于阈值则视为语义匹配
@dataclass
class SemanticEntry:
"""语义缓存条目"""
query: str # 原始查询
key: str # 精确缓存键
embedding: np.ndarray # 嵌入向量
result: Any # 缓存结果
access_count: int = 0
last_access: float = 0.0
class SemanticCache:
"""
语义缓存层:基于向量相似度匹配语义相近的查询
采用扁平列表 + 暴力搜索(适合中小规模缓存)
"""
def __init__(
self,
embed_func: Callable[[str], list[float]],
similarity_threshold: float = 0.92,
max_entries: int = 5000,
):
self.embed_func = embed_func
self.threshold = similarity_threshold
self.max_entries = max_entries
self.entries: dict[str, SemanticEntry] = {} # key → entry
self.lock = threading.RLock()
@staticmethod
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""余弦相似度计算"""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
def find_similar(self, query: str) -> Optional[tuple[str, Any, float]]:
"""
在已有缓存中查找语义相似的查询
返回 (原查询key, 缓存结果, 相似度分数)
"""
with self.lock:
if not self.entries:
return None
query_vec = np.array(self.embed_func(query), dtype=np.float32)
best_key = None
best_sim = 0.0
best_result = None
for key, entry in self.entries.items():
sim = self.cosine_similarity(query_vec, entry.embedding)
if sim > best_sim:
best_sim = sim
best_key = key
best_result = entry.result
if best_sim >= self.threshold and best_key is not None:
self.entries[best_key].access_count += 1
self.entries[best_key].last_access = time.time()
return (best_key, best_result, best_sim)
return None
def add(self, query: str, result: Any, exact_key: str):
"""添加新的语义缓存条目"""
with self.lock:
if len(self.entries) >= self.max_entries:
# 淘汰最久未访问的条目
oldest_key = min(
self.entries, key=lambda k: self.entries[k].last_access
)
del self.entries[oldest_key]
embedding = np.array(
self.embed_func(query), dtype=np.float32
)
self.entries[exact_key] = SemanticEntry(
query=query,
key=exact_key,
embedding=embedding,
result=result,
last_access=time.time(),
)
def evict(self, exact_key: str):
"""移除指定缓存(与精确缓存保持同步)"""
with self.lock:
self.entries.pop(exact_key, None)
这里的关键参数是 similarity_threshold(相似度阈值)。0.92 是一个经过实践检验的起始值——太低会误匹配(把"如何写 Python 代码"和"如何写 Java 代码"当成相同问题),太高则退化为精确匹配。
余弦相似度的计算实现中,需要注意零向量保护:np.linalg.norm 返回 0 时做除法会得到 NaN,导致后续比较逻辑异常。
六、缓存策略引擎:TTL + 自适应淘汰
将精确缓存和语义缓存整合为一个统一的缓存系统,再加入 TTL 过期管理和自适应淘汰策略:
class AICacheSystem:
"""
AI 缓存系统:整合精确缓存 + 语义缓存 + 策略管理
完整的读写穿透接口
"""
def __init__(
self,
embed_func: Callable[[str], list[float]],
exact_capacity: int = 10000,
semantic_threshold: float = 0.92,
default_ttl: int = 3600,
enable_semantic: bool = True,
):
self.exact_cache = ExactCache(capacity=exact_capacity, default_ttl=default_ttl)
self.semantic_cache = SemanticCache(
embed_func=embed_func,
similarity_threshold=semantic_threshold,
) if enable_semantic else None
self.enable_semantic = enable_semantic
self.default_ttl = default_ttl
self.embed_func = embed_func
self.lock = threading.RLock()
self.stats = {
"exact_hits": 0,
"semantic_hits": 0,
"misses": 0,
}
def get(
self,
query: str,
context: Optional[dict] = None,
) -> Optional[Any]:
"""
两级缓存读取:
1. 先查精确缓存
2. 未命中则查语义缓存
3. 都未命中返回 None
"""
# 第一级:精确缓存
result = self.exact_cache.get(query, context)
if result is not None:
self.stats["exact_hits"] += 1
return result
# 第二级:语义缓存
if self.enable_semantic:
similar = self.semantic_cache.find_similar(query)
if similar is not None:
similar_key, result, similarity = similar
self.stats["semantic_hits"] += 1
# 将语义命中的查询也写入精确缓存,下次秒级响应
self.exact_cache.set(query, result, context, ttl=600)
return result
self.stats["misses"] += 1
return None
def set(
self,
query: str,
value: Any,
context: Optional[dict] = None,
ttl: Optional[int] = None,
):
"""写入缓存(两级同步写入)"""
exact_key = self.exact_cache.make_key(query, context)
# 写入精确缓存
self.exact_cache.set(query, value, context, ttl=ttl)
# 写入语义缓存
if self.enable_semantic:
self.semantic_cache.add(query, value, exact_key)
def get_or_compute(
self,
query: str,
compute_fn: Callable[[], Any],
context: Optional[dict] = None,
ttl: Optional[int] = None,
) -> Any:
"""
缓存穿透读取模式:
命中 → 返回缓存值
未命中 → 调用 compute_fn → 缓存结果 → 返回
"""
result = self.get(query, context)
if result is not None:
return result
result = compute_fn()
self.set(query, result, context, ttl=ttl)
return result
def summary(self) -> dict:
"""缓存系统运行统计"""
total = sum(self.stats.values())
return {
"exact_hits": self.stats["exact_hits"],
"semantic_hits": self.stats["semantic_hits"],
"misses": self.stats["misses"],
"total": total,
"hit_rate": (self.stats["exact_hits"] + self.stats["semantic_hits"])
/ total if total > 0 else 0.0,
"semantic_rate": self.stats["semantic_hits"] / total if total > 0 else 0.0,
"exact_hit_rate": self.exact_cache.hit_rate(),
"exact_entries": len(self.exact_cache.lru),
"semantic_entries": len(self.semantic_cache.entries)
if self.semantic_cache else 0,
}
get_or_compute 是缓存系统的核心接口模式。调用方只需要传入「查询」和「计算函数」,缓存系统自动处理命中判断和数据写入。这种模式也称为缓存穿透保护——避免在缓存未命中时多个并发请求同时触发昂贵的计算。
七、嵌入向量服务的封装
语义缓存依赖嵌入向量生成,我们封装一个通用的嵌入服务和分词器:
class EmbeddingService:
"""
嵌入向量生成服务
支持本地模型和远程 API 两种模式
"""
def __init__(self, mode: str = "api", api_url: str = None, api_key: str = None):
self.mode = mode
self.api_url = api_url
self.api_key = api_key
self._model = None
self._tokenizer = None
if mode == "local":
self._init_local_model()
def _init_local_model(self):
"""初始化本地嵌入模型(以 sentence-transformers 为例)"""
try:
from sentence_transformers import SentenceTransformer
# 使用轻量级多语言模型
self._model = SentenceTransformer(
"paraphrase-multilingual-MiniLM-L12-v2"
)
except ImportError:
raise ImportError(
"需要安装 sentence-transformers: pip install sentence-transformers"
)
def embed(self, text: str) -> list[float]:
"""生成文本的嵌入向量"""
if self.mode == "local" and self._model:
vec = self._model.encode(text, normalize_embeddings=True)
return vec.tolist()
elif self.mode == "api":
return self._api_embed(text)
else:
raise ValueError(f"Unknown mode: {self.mode}")
def _api_embed(self, text: str) -> list[float]:
"""通过 API 生成嵌入向量"""
import requests
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": "text-embedding-v2",
"input": text,
}
resp = requests.post(
f"{self.api_url}/embeddings",
headers=headers,
json=payload,
timeout=10,
)
resp.raise_for_status()
data = resp.json()
return data["data"][0]["embedding"]
def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""批量生成嵌入向量"""
if self.mode == "local" and self._model:
vecs = self._model.encode(texts, normalize_embeddings=True)
return vecs.tolist()
else:
return [self.embed(t) for t in texts]
class QueryTokenizer:
"""
查询分词与特征提取
用于缓存键生成和语义分析的预处理
"""
def __init__(self, max_length: int = 512):
self.max_length = max_length
def tokenize(self, text: str) -> dict:
"""对查询做基础分词,返回统计特征"""
import re
text = text[:self.max_length]
words = re.findall(r'\w+', text.lower())
return {
"word_count": len(words),
"char_count": len(text),
"has_question": "?" in text or "?" in text,
"first_word": words[0] if words else "",
"keyword_set": set(words),
}
def jaccard_similarity(self, query_a: str, query_b: str) -> float:
"""
Jaccard 相似度——作为语义缓存的快速预过滤
如果关键词完全不重叠,大概率语义也不同
"""
feat_a = self.tokenize(query_a)
feat_b = self.tokenize(query_b)
set_a = feat_a["keyword_set"]
set_b = feat_b["keyword_set"]
intersection = len(set_a & set_b)
union = len(set_a | set_b)
return intersection / union if union > 0 else 0.0
嵌入服务支持本地和远程两种模式。本地模式使用 paraphrase-multilingual-MiniLM-L12-v2——一个 384 维的轻量多语言模型,66MB 内存占用,CPU 上单次编码只需 5-10ms,非常适合实时缓存场景。
QueryTokenizer 中的 Jaccard 相似度可作为语义缓存的快速预过滤——如果两个查询的关键词完全不重叠(Jaccard = 0),大概率语义也不同,无需浪费嵌入向量的相似度计算。
八、写穿透与缓存预热
生产环境中,缓存数据和真实数据可能不一致。写穿透(Write-Through)策略保证每次写入都同步更新缓存:
class WriteThroughCache(AICacheSystem):
"""
写穿透缓存装饰器
每次写入数据时同步更新缓存
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._write_hooks: list[Callable] = []
def on_write(self, hook: Callable[[str, Any], None]):
"""注册写入回调(用于持久化同步)"""
self._write_hooks.append(hook)
def write_through(
self,
key: str,
value: Any,
context: Optional[dict] = None,
ttl: Optional[int] = None,
sync_persistence: bool = True,
):
"""写穿透:写入缓存 + 触发持久化回调"""
self.set(key, value, context, ttl=ttl)
if sync_persistence:
for hook in self._write_hooks:
try:
hook(key, value)
except Exception as e:
print(f"[Cache] Write hook failed: {e}")
def warm_up(self, data: list[tuple[str, Any, Optional[dict]]]):
"""
缓存预热:批量加载热点数据
一般在系统启动时调用
"""
for query, value, context in data:
self.set(query, value, context)
print(
f"[Cache] Warm-up complete: {len(data)} entries loaded"
)
class CachePersistence:
"""缓存持久化:磁盘备份与恢复"""
def __init__(self, filepath: str = "/tmp/ai_cache_backup.json"):
self.filepath = filepath
def save(self, data: dict):
with open(self.filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, default=str)
def load(self) -> dict:
try:
with open(self.filepath, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
return {}
def periodic_save(self, cache: AICacheSystem, interval: int = 300):
"""定时持久化缓存快照"""
import threading
def _save_loop():
while True:
time.sleep(interval)
snapshot = cache.summary()
self.save({"timestamp": time.time(), "stats": snapshot})
thread = threading.Thread(target=_save_loop, daemon=True)
thread.start()
写穿透确保缓存与后端数据源的一致性。on_write 注册的回调可以对接数据库写入、日志记录或消息队列推送。
预热接口 warm_up 在系统启动时调用,加载前一天的热点数据,避免刚启动时的缓存空窗期。
九、完整使用示例
将以上所有组件组合起来,演示一个真实的 AI 问答缓存场景:
def demo_ai_cache():
"""
AI 缓存系统完整演示
使用一个简单的模拟嵌入函数(真实环境应替换为实际嵌入服务)
"""
import random
# 模拟嵌入函数(演示用)
def mock_embed(text: str) -> list[float]:
random.seed(hash(text) % (2**31))
return [random.random() for _ in range(4)]
# 构建缓存系统
cache = AICacheSystem(
embed_func=mock_embed,
exact_capacity=1000,
semantic_threshold=0.85, # 演示用较低阈值
default_ttl=3600,
)
# 模拟 LLM 调用
def expensive_llm_call(query: str) -> str:
print(f" [LLM] 正在计算: '{query}'")
time.sleep(0.5) # 模拟推理延迟
return f"这是关于「{query}」的回答。"
# 模拟查询序列
queries = [
("公司的年假政策是什么", False),
("Please explain annual leave policy", False),
("年假几天", False),
("年假政策", False), # 应命中语义缓存
("公司的年假政策是什么", False), # 应命中精确缓存
("怎么请假", True), # 不相关,不应命中
]
print("=== AI 缓存系统演示 ===\n")
for i, (query, should_miss) in enumerate(queries):
print(f"查询 {i+1}: '{query}'")
start = time.time()
result = cache.get(query)
if result is not None:
elapsed = time.time() - start
print(f" [缓存] ✅ 命中! ({elapsed:.3f}s)")
else:
print(f" [缓存] ❌ 未命中,执行 LLM 调用...")
result = expensive_llm_call(query)
cache.set(query, result)
elapsed = time.time() - start
print(f" [完成] ({elapsed:.3f}s)")
print(f" 结果: {result}\n")
# 打印统计
print("=== 缓存统计 ===")
stats = cache.summary()
for k, v in stats.items():
print(f" {k}: {v}")
def demo_semantic_matching():
"""
语义缓存匹配演示
展示不同表述如何命中同一个缓存结果
"""
def mock_embed(text: str) -> list[float]:
# 为演示设计:关键词越相似,向量越接近
keywords = ["密码", "重置", "忘记", "登录", "账户", "修改"]
vec = [0.0] * len(keywords)
for i, kw in enumerate(keywords):
if kw in text:
vec[i] = 1.0
# 归一化
norm = sum(v * v for v in vec) ** 0.5
return [v / norm for v in vec] if norm > 0 else vec
cache = AICacheSystem(
embed_func=mock_embed,
semantic_threshold=0.75,
)
# 将原始问题缓存
original = "如何重置密码?"
answer = "步骤:1. 点击登录页"忘记密码" 2. 输入注册邮箱 3. 查收重置链接 4. 设置新密码"
cache.set(original, answer)
# 测试语义相近的查询
test_queries = [
"忘记密码怎么办", # 应命中
"密码重置步骤", # 应命中
"怎么修改登录密码", # 应命中
"服务器配置参数", # 不应命中
]
print("=== 语义匹配演示 ===\n")
print(f"原始: '{original}'")
print(f"答案: {answer}\n")
for q in test_queries:
result = cache.get(q)
hit = "✅ 命中" if result else "❌ 未命中"
print(f"查询: '{q}' → {hit}")
if __name__ == "__main__":
demo_ai_cache()
print()
demo_semantic_matching()
运行演示代码,你会看到:
- 第一次查询「公司的年假政策是什么」:未命中 → 执行 LLM 调用(0.5s 延迟)
- 第二次查询「Please explain annual leave policy」:语义命中 → 立即返回缓存结果(毫秒级)
- 第三次查询「公司的年假政策是什么」:精确命中 → 毫秒级响应
- 第四次查询「怎么请假」:不相关 → 未命中(正确行为)
十、缓存性能指标与监控
好的缓存系统需要可观测。以下是关键的监控指标和实现:
class CacheMonitor:
"""
缓存监控器:记录、聚合、报告缓存性能指标
"""
def __init__(self, cache: AICacheSystem, window_size: int = 3600):
self.cache = cache
self.window_size = window_size
self.timeline: list[dict] = []
def snapshot(self) -> dict:
"""生成当前缓存快照"""
stats = self.cache.summary()
now = time.time()
snapshot = {
"timestamp": now,
"stats": stats,
"latency_p95": self._estimate_p95_latency(),
"memory_usage_mb": self.cache.exact_cache.lru.current_memory
/ (1024 * 1024),
"eviction_rate": self._eviction_rate(),
}
self.timeline.append(snapshot)
# 只保留时间窗口内的快照
cutoff = now - self.window_size
self.timeline = [s for s in self.timeline if s["timestamp"] > cutoff]
return snapshot
def _estimate_p95_latency(self) -> float:
"""估算 P95 延迟(单位:毫秒)"""
recent = self.timeline[-100:] if len(self.timeline) > 100 else self.timeline
if not recent:
return 0.0
latencies = sorted(
s.get("latency_p95", 0) for s in recent
)
idx = int(len(latencies) * 0.95)
return latencies[idx] if idx < len(latencies) else latencies[-1]
def _eviction_rate(self) -> float:
"""计算最近窗口内的淘汰率"""
if len(self.timeline) < 2:
return 0.0
recent = self.timeline[-2:]
evictions_diff = (
recent[1]["stats"]["exact_hits"]
+ recent[1]["stats"]["semantic_hits"]
+ recent[1]["stats"]["misses"]
) - (
recent[0]["stats"]["exact_hits"]
+ recent[0]["stats"]["semantic_hits"]
+ recent[0]["stats"]["misses"]
)
return max(0, evictions_diff) / (len(self.cache.exact_cache.lru) + 1)
def report(self) -> str:
"""生成人类可读的性能报告"""
s = self.snapshot()
stats = s["stats"]
report = f"""
=== AI 缓存性能报告 ===
时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(s['timestamp']))}
命中统计:
精确命中: {stats['exact_hits']}
语义命中: {stats['semantic_hits']}
总未命中: {stats['misses']}
综合命中率: {stats['hit_rate']*100:.1f}%
缓存规模:
精确缓存条目: {stats['exact_entries']}
语义缓存条目: {stats['semantic_entries']}
内存占用: {s['memory_usage_mb']:.1f} MB
预估值 P95 延迟: {s['latency_p95']:.1f} ms
"""
return report.strip()
关键的缓存性能指标解读:
- 命中率:综合命中率 60% 是底线。低于 60% 说明缓存策略需要调整——可能相似度阈值过高(语义缓存几乎不工作),或者 TTL 太短导致频繁过期。
- 语义命中率:在综合命中率中,语义命中占比越高,说明系统对「不同表述同一问题」的缓存效果越好。如果精确命中率很高而语义命中率很低,说明用户查询高度标准化,语义缓存收益有限。
- 淘汰率:高频淘汰意味着缓存容量不足。如果在内存还有余量的情况下淘汰率高,说明 LRU 的容量上限设得太低。
十一、常见问题与优化策略
11.1 缓存雪崩
大量缓存同时过期,导致所有请求穿透到后端。
解决方案:过期时间随机化
def _randomized_ttl(base_ttl: int, jitter: float = 0.1) -> int:
"""在基础 TTL 上增加随机抖动,防止缓存雪崩"""
import random
jitter_range = int(base_ttl * jitter)
return base_ttl + random.randint(-jitter_range, jitter_range)
11.2 缓存击穿
热点键过期瞬间,大量并发请求同时查询同一个未命中的键。
解决方案:互斥锁 + 双重检查
class HotKeyProtector:
"""热点键保护:防止缓存击穿"""
def __init__(self):
self._locks: dict[str, threading.Lock] = {}
self._global_lock = threading.Lock()
def _get_lock(self, key: str) -> threading.Lock:
with self._global_lock:
if key not in self._locks:
self._locks[key] = threading.Lock()
return self._locks[key]
def protect(self, key: str, compute_fn: Callable[[], Any]) -> Any:
"""
互斥锁保护:
同一时刻只有一个线程在计算,其余等待
"""
lock = self._get_lock(key)
with lock:
return compute_fn()
11.3 缓存预热策略
系统重启后,缓存为空。可以使用滚动预热策略:
class RollingWarmUp:
"""
滚动预热:从慢到快逐步加载缓存数据
避免一次性加载大量数据导致系统抖动
"""
def __init__(self, cache: AICacheSystem, loader: Callable[[int, int], list]):
self.cache = cache
self.loader = loader
def warm_up(self, total: int, batch_size: int = 50, delay: float = 0.1):
"""分批加载缓存数据"""
loaded = 0
while loaded < total:
batch = self.loader(loaded, min(batch_size, total - loaded))
for query, value, context in batch:
self.cache.set(query, value, context)
loaded += len(batch)
print(
f"[WarmUp] {loaded}/{total} ({loaded*100//total}%)"
)
if loaded < total:
time.sleep(delay) # 防止系统过载
11.4 缓存容量规划
| 组件 | 内存估算公式 | 100万条目示例 |
|---|---|---|
| LRU 链表 | key + value + 指针 ≈ 200 bytes | 200 MB |
| 语义缓存 | 384维 float32 = 1536 bytes + 字符串 | 1.5 GB |
| 索引结构 | 近似 10% 额外开销 | 150 MB |
实践中,精确缓存建议分配 256-512MB,语义缓存根据嵌入向量维度计算。384 维向量每条 1.5KB,5000 条约 7.5MB,内存开销可控。
十二、总结
本文从零实现了 AI 缓存系统的完整架构:
- 精确缓存层:基于 LRU + 双向链表的 O(1) 查找,结合 TTL 过期和内存上限控制
- 语义缓存层:嵌入向量 + 余弦相似度,识别语义相似但表述不同的查询
- 缓存策略引擎:两级缓存穿透保护,写穿透保证数据一致性
- 运维工具:缓存监控、性能指标、预热策略、雪崩和击穿防护
选择缓存策略时的指导原则:
- 查询固定且重复率高:只需精确缓存,语义缓存收益有限
- 自然语言查询、用户表述多样:必须启用语义缓存,相似度阈值从 0.92 开始调优
- 热点数据频繁更新:缩短 TTL(300-600秒)+ 写穿透模式
- 冷启动场景:实现缓存预热 + 滚动加载,避免缓存空窗期
AI 缓存是一个「一分设计、九分调优」的工程。先根据业务场景选择合适的缓存策略组合,然后通过监控指标持续调整参数。理解和掌握缓存系统的底层原理,远比简单接入 Redis 等中间件更有价值——因为你可以在任何时候根据自己的场景需求,做出最优的设计决策。
📚 延伸阅读
如果你对 DeepSeek 的实战 用法感兴趣,推荐阅读我的另一篇文章:
👉 DeepSeek 实战指南:提示词工程、API 集成与效率提升全攻略
这篇文章系统地拆解了 DeepSeek 的提示词工程技巧、API 封装方法以及日常效率提升场景,全文代码可直接运行,适合已经上手 DeepSeek 但希望更高效使用的开发者。
本文是"手写 AI 系统"系列文章之一。该系列从零实现 AI 系统中的关键组件,涵盖 RAG、Agent、Function Calling、MCP 等核心技术,帮助你深入理解底层原理,构建属于自己的 AI 工具。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)