一、为什么 AI 系统需要缓存

在实际的 AI 应用中,重复计算是一个普遍存在的性能瓶颈。以 RAG 系统为例,当多个用户询问相似问题(如"公司的薪酬政策是什么"和"请说明薪酬规定"),如果每次都重新执行嵌入向量检索、大模型推理等全套流程,不仅响应延迟高,还会产生大量 API 调用开销。

缓存的核心价值体现在三个维度:

降低延迟:缓存命中时直接返回结果,避免大模型推理的 3-10 秒等待。对于实时对话系统,这直接决定了用户体验的流畅度。

节省成本:LLM API 按 token 计费。一个 2000 token 的查询响应,在 GPT-4 上约需 0.06 元。即使每天只有 1000 次重复查询,一年就能省下 2 万元以上。

系统保护:缓存吸收突发流量峰值,防止后端模型服务过载。这在并发高峰期尤为重要。

本文将从零构建一个完整的 AI 缓存系统,覆盖精确缓存、语义缓存、LRU 淘汰、TTL 过期等核心机制。全部代码手写实现,不依赖任何缓存框架。

二、缓存系统架构总览

一个完整的 AI 缓存系统包含以下层次:

用户请求
    ↓
[ 缓存层 ]
    ├── 一级缓存:精确匹配(内存哈希表)
    ├── 二级缓存:语义匹配(向量相似度检索)
    └── 三级缓存:持久化存储(可选 Redis/磁盘)
    ↓
[ 缓存策略引擎 ]
    ├── LRU 淘汰策略
    ├── TTL 过期策略
    └── 预热与写穿透策略
    ↓
[ 实际计算 ]
    └── LLM 推理 / RAG 检索

第一层精确缓存用于完全相同的查询,毫秒级响应。第二层语义缓存用于含义相似但表述不同的查询,需要向量化 + 相似度检索。第三层持久化缓存用于系统重启后恢复热点数据。

三、基础数据结构:双向链表 + 哈希表

LRU 缓存的核心数据结构是「双向链表 + 哈希表」组合。双向链表维护访问顺序,哈希表提供 O(1) 查找。

from __future__ import annotations
import hashlib
import json
import time
import threading
from typing import Any, Optional, Callable
from dataclasses import dataclass, field
import numpy as np


@dataclass
class CacheNode:
    """缓存节点,双向链表的基本单元"""
    key: str
    value: Any
    prev: Optional[CacheNode] = None
    next: Optional[CacheNode] = None
    size: int = 0       # 缓存项大小(字节)
    hits: int = 0        # 访问次数
    created_at: float = 0.0
    accessed_at: float = 0.0


class LRUCache:
    """
    基于双向链表 + 哈希表的 LRU 缓存
    特点:O(1) 查找、插入、删除
    """

    def __init__(self, capacity: int = 1000, max_memory_mb: int = 256):
        self.capacity = capacity
        self.max_memory = max_memory_mb * 1024 * 1024  # 转为字节
        self.cache: dict[str, CacheNode] = {}
        self.head = CacheNode(key="__head__", value=None)  # 伪头节点
        self.tail = CacheNode(key="__tail__", value=None)  # 伪尾节点
        self.head.next = self.tail
        self.tail.prev = self.head
        self.current_memory = 0
        self.lock = threading.RLock()  # 线程安全
        self._init_metrics()

    def _init_metrics(self):
        self.metrics = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "total_lookups": 0,
        }

    def _remove_node(self, node: CacheNode):
        """从链表中移除节点(不操作哈希表)"""
        prev_node = node.prev
        next_node = node.next
        prev_node.next = next_node
        next_node.prev = prev_node

    def _add_to_head(self, node: CacheNode):
        """将节点插入到头部(最近使用)"""
        node.prev = self.head
        node.next = self.head.next
        self.head.next.prev = node
        self.head.next = node

    def _move_to_head(self, node: CacheNode):
        """将已存在节点移到头部"""
        self._remove_node(node)
        self._add_to_head(node)

    def _evict_tail(self) -> Optional[CacheNode]:
        """移除最久未使用的节点(尾部)"""
        if self.tail.prev == self.head:
            return None
        node = self.tail.prev
        self._remove_node(node)
        del self.cache[node.key]
        self.current_memory -= node.size
        self.metrics["evictions"] += 1
        return node

    def get(self, key: str) -> Optional[Any]:
        """获取缓存项,同时更新访问顺序"""
        with self.lock:
            self.metrics["total_lookups"] += 1
            node = self.cache.get(key)
            if node is None:
                self.metrics["misses"] += 1
                return None
            self._move_to_head(node)
            node.hits += 1
            node.accessed_at = time.time()
            self.metrics["hits"] += 1
            return node.value

    def put(self, key: str, value: Any, cost: int = 0):
        """插入或更新缓存项"""
        with self.lock:
            now = time.time()
            node = self.cache.get(key)
            if node is not None:
                # 更新已有节点
                self.current_memory -= node.size
                node.value = value
                node.size = cost
                node.accessed_at = now
                node.created_at = now if node.created_at == 0 else node.created_at
                self._move_to_head(node)
                self.current_memory += cost
                return

            # 新建节点
            node = CacheNode(
                key=key,
                value=value,
                size=cost,
                hits=0,
                created_at=now,
                accessed_at=now,
            )

            # 检查容量和内存限制
            while (
                len(self.cache) >= self.capacity
                or self.current_memory + cost > self.max_memory
            ):
                if not self._evict_tail():
                    break

            self.cache[key] = node
            self._add_to_head(node)
            self.current_memory += cost

    def invalidate(self, key: str):
        """主动失效指定缓存项"""
        with self.lock:
            node = self.cache.pop(key, None)
            if node:
                self.current_memory -= node.size
                self._remove_node(node)

    def clear(self):
        """清空所有缓存"""
        with self.lock:
            self.cache.clear()
            self.head.next = self.tail
            self.tail.prev = self.head
            self.current_memory = 0
            self._init_metrics()

    def hit_rate(self) -> float:
        """缓存命中率"""
        total = self.metrics["hits"] + self.metrics["misses"]
        return self.metrics["hits"] / total if total > 0 else 0.0

    def __len__(self):
        return len(self.cache)

    def __contains__(self, key: str):
        with self.lock:
            return key in self.cache

这段代码的核心设计思想是使用哨兵节点(dummy head/tail)消除边界条件判断。在插入和删除操作中,头尾虚拟节点让代码不需要单独处理空链表情况,显著减少 bug 来源。

threading.RLock(可重入锁)保证线程安全,因为在 put 方法内可能触发 _evict_tail,如果使用普通锁 Lock,同一线程内再次获取同一把锁会造成死锁。

四、精确缓存:查询键的标准化

精确缓存要求完全相同的关键字才能命中。但用户查询可能有细微差异(多余空格、大小写、标点符号),我们需要先对查询进行键标准化

class ExactCache:
    """
    精确缓存层:完全匹配查询
    支持查询键标准化和 TTL 过期
    """

    def __init__(self, capacity: int = 10000, default_ttl: int = 3600):
        self.lru = LRUCache(capacity=capacity)
        self.default_ttl = default_ttl
        self.ttl_map: dict[str, float] = {}  # key → 过期时间戳

    @staticmethod
    def normalize_query(query: str) -> str:
        """标准化查询字符串,提高命中率"""
        # 折叠空白
        normalized = " ".join(query.split())
        # 转小写
        normalized = normalized.lower().strip()
        # 统一标点
        normalized = normalized.replace("?", "?").replace("!", "!")
        # 移除末尾句号
        normalized = normalized.rstrip("。.")
        return normalized

    def make_key(self, query: str, context: Optional[dict] = None) -> str:
        """生成缓存键:查询 + 上下文哈希"""
        normalized = self.normalize_query(query)
        if context:
            context_str = json.dumps(context, sort_keys=True)
            context_hash = hashlib.md5(context_str.encode()).hexdigest()[:8]
            return f"{normalized}::ctx:{context_hash}"
        return normalized

    def get(self, query: str, context: Optional[dict] = None) -> Optional[Any]:
        key = self.make_key(query, context)
        # 检查 TTL
        if key in self.ttl_map and time.time() > self.ttl_map[key]:
            self.lru.invalidate(key)
            del self.ttl_map[key]
            return None
        return self.lru.get(key)

    def set(
        self,
        query: str,
        value: Any,
        context: Optional[dict] = None,
        ttl: Optional[int] = None,
    ):
        key = self.make_key(query, context)
        cost = len(json.dumps(value, default=str).encode("utf-8"))
        self.lru.put(key, value, cost=cost)
        self.ttl_map[key] = time.time() + (ttl or self.default_ttl)

    def hit_rate(self) -> float:
        return self.lru.hit_rate()

键标准化中," ".join(query.split()) 是非常实用的技巧:它同时完成了去除首尾空格折叠中间连续空格两件事,比 strip() + 正则替换更简洁高效。

make_key 中的 context 参数支持按场景隔离缓存。例如,同一个问题"今天天气如何"在不同城市上下文中应该有不同的缓存。

五、语义缓存:向量相似度匹配

语义缓存是 AI 缓存的精髓所在。当用户问"如何重置密码"和"忘了密码怎么办"时,虽然字符串不同,但语义高度相似,应该复用同一份缓存结果。

实现语义缓存包含三个步骤:

  1. 查询向量化:将文本转为高维向量
  2. 相似度检索:在已有缓存向量库中查找最近邻
  3. 相似度阈值判断:高于阈值则视为语义匹配
@dataclass
class SemanticEntry:
    """语义缓存条目"""
    query: str           # 原始查询
    key: str             # 精确缓存键
    embedding: np.ndarray  # 嵌入向量
    result: Any          # 缓存结果
    access_count: int = 0
    last_access: float = 0.0


class SemanticCache:
    """
    语义缓存层:基于向量相似度匹配语义相近的查询
    采用扁平列表 + 暴力搜索(适合中小规模缓存)
    """

    def __init__(
        self,
        embed_func: Callable[[str], list[float]],
        similarity_threshold: float = 0.92,
        max_entries: int = 5000,
    ):
        self.embed_func = embed_func
        self.threshold = similarity_threshold
        self.max_entries = max_entries
        self.entries: dict[str, SemanticEntry] = {}  # key → entry
        self.lock = threading.RLock()

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """余弦相似度计算"""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def find_similar(self, query: str) -> Optional[tuple[str, Any, float]]:
        """
        在已有缓存中查找语义相似的查询
        返回 (原查询key, 缓存结果, 相似度分数)
        """
        with self.lock:
            if not self.entries:
                return None

            query_vec = np.array(self.embed_func(query), dtype=np.float32)
            best_key = None
            best_sim = 0.0
            best_result = None

            for key, entry in self.entries.items():
                sim = self.cosine_similarity(query_vec, entry.embedding)
                if sim > best_sim:
                    best_sim = sim
                    best_key = key
                    best_result = entry.result

            if best_sim >= self.threshold and best_key is not None:
                self.entries[best_key].access_count += 1
                self.entries[best_key].last_access = time.time()
                return (best_key, best_result, best_sim)

            return None

    def add(self, query: str, result: Any, exact_key: str):
        """添加新的语义缓存条目"""
        with self.lock:
            if len(self.entries) >= self.max_entries:
                # 淘汰最久未访问的条目
                oldest_key = min(
                    self.entries, key=lambda k: self.entries[k].last_access
                )
                del self.entries[oldest_key]

            embedding = np.array(
                self.embed_func(query), dtype=np.float32
            )
            self.entries[exact_key] = SemanticEntry(
                query=query,
                key=exact_key,
                embedding=embedding,
                result=result,
                last_access=time.time(),
            )

    def evict(self, exact_key: str):
        """移除指定缓存(与精确缓存保持同步)"""
        with self.lock:
            self.entries.pop(exact_key, None)

这里的关键参数是 similarity_threshold(相似度阈值)。0.92 是一个经过实践检验的起始值——太低会误匹配(把"如何写 Python 代码"和"如何写 Java 代码"当成相同问题),太高则退化为精确匹配。

余弦相似度的计算实现中,需要注意零向量保护:np.linalg.norm 返回 0 时做除法会得到 NaN,导致后续比较逻辑异常。

六、缓存策略引擎:TTL + 自适应淘汰

将精确缓存和语义缓存整合为一个统一的缓存系统,再加入 TTL 过期管理和自适应淘汰策略:

class AICacheSystem:
    """
    AI 缓存系统:整合精确缓存 + 语义缓存 + 策略管理
    完整的读写穿透接口
    """

    def __init__(
        self,
        embed_func: Callable[[str], list[float]],
        exact_capacity: int = 10000,
        semantic_threshold: float = 0.92,
        default_ttl: int = 3600,
        enable_semantic: bool = True,
    ):
        self.exact_cache = ExactCache(capacity=exact_capacity, default_ttl=default_ttl)
        self.semantic_cache = SemanticCache(
            embed_func=embed_func,
            similarity_threshold=semantic_threshold,
        ) if enable_semantic else None
        self.enable_semantic = enable_semantic
        self.default_ttl = default_ttl
        self.embed_func = embed_func
        self.lock = threading.RLock()
        self.stats = {
            "exact_hits": 0,
            "semantic_hits": 0,
            "misses": 0,
        }

    def get(
        self,
        query: str,
        context: Optional[dict] = None,
    ) -> Optional[Any]:
        """
        两级缓存读取:
        1. 先查精确缓存
        2. 未命中则查语义缓存
        3. 都未命中返回 None
        """
        # 第一级:精确缓存
        result = self.exact_cache.get(query, context)
        if result is not None:
            self.stats["exact_hits"] += 1
            return result

        # 第二级:语义缓存
        if self.enable_semantic:
            similar = self.semantic_cache.find_similar(query)
            if similar is not None:
                similar_key, result, similarity = similar
                self.stats["semantic_hits"] += 1
                # 将语义命中的查询也写入精确缓存,下次秒级响应
                self.exact_cache.set(query, result, context, ttl=600)
                return result

        self.stats["misses"] += 1
        return None

    def set(
        self,
        query: str,
        value: Any,
        context: Optional[dict] = None,
        ttl: Optional[int] = None,
    ):
        """写入缓存(两级同步写入)"""
        exact_key = self.exact_cache.make_key(query, context)

        # 写入精确缓存
        self.exact_cache.set(query, value, context, ttl=ttl)

        # 写入语义缓存
        if self.enable_semantic:
            self.semantic_cache.add(query, value, exact_key)

    def get_or_compute(
        self,
        query: str,
        compute_fn: Callable[[], Any],
        context: Optional[dict] = None,
        ttl: Optional[int] = None,
    ) -> Any:
        """
        缓存穿透读取模式:
        命中 → 返回缓存值
        未命中 → 调用 compute_fn → 缓存结果 → 返回
        """
        result = self.get(query, context)
        if result is not None:
            return result

        result = compute_fn()
        self.set(query, result, context, ttl=ttl)
        return result

    def summary(self) -> dict:
        """缓存系统运行统计"""
        total = sum(self.stats.values())
        return {
            "exact_hits": self.stats["exact_hits"],
            "semantic_hits": self.stats["semantic_hits"],
            "misses": self.stats["misses"],
            "total": total,
            "hit_rate": (self.stats["exact_hits"] + self.stats["semantic_hits"])
            / total if total > 0 else 0.0,
            "semantic_rate": self.stats["semantic_hits"] / total if total > 0 else 0.0,
            "exact_hit_rate": self.exact_cache.hit_rate(),
            "exact_entries": len(self.exact_cache.lru),
            "semantic_entries": len(self.semantic_cache.entries)
            if self.semantic_cache else 0,
        }

get_or_compute 是缓存系统的核心接口模式。调用方只需要传入「查询」和「计算函数」,缓存系统自动处理命中判断和数据写入。这种模式也称为缓存穿透保护——避免在缓存未命中时多个并发请求同时触发昂贵的计算。

七、嵌入向量服务的封装

语义缓存依赖嵌入向量生成,我们封装一个通用的嵌入服务和分词器:

class EmbeddingService:
    """
    嵌入向量生成服务
    支持本地模型和远程 API 两种模式
    """

    def __init__(self, mode: str = "api", api_url: str = None, api_key: str = None):
        self.mode = mode
        self.api_url = api_url
        self.api_key = api_key
        self._model = None
        self._tokenizer = None

        if mode == "local":
            self._init_local_model()

    def _init_local_model(self):
        """初始化本地嵌入模型(以 sentence-transformers 为例)"""
        try:
            from sentence_transformers import SentenceTransformer
            # 使用轻量级多语言模型
            self._model = SentenceTransformer(
                "paraphrase-multilingual-MiniLM-L12-v2"
            )
        except ImportError:
            raise ImportError(
                "需要安装 sentence-transformers: pip install sentence-transformers"
            )

    def embed(self, text: str) -> list[float]:
        """生成文本的嵌入向量"""
        if self.mode == "local" and self._model:
            vec = self._model.encode(text, normalize_embeddings=True)
            return vec.tolist()
        elif self.mode == "api":
            return self._api_embed(text)
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    def _api_embed(self, text: str) -> list[float]:
        """通过 API 生成嵌入向量"""
        import requests

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": "text-embedding-v2",
            "input": text,
        }
        resp = requests.post(
            f"{self.api_url}/embeddings",
            headers=headers,
            json=payload,
            timeout=10,
        )
        resp.raise_for_status()
        data = resp.json()
        return data["data"][0]["embedding"]

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """批量生成嵌入向量"""
        if self.mode == "local" and self._model:
            vecs = self._model.encode(texts, normalize_embeddings=True)
            return vecs.tolist()
        else:
            return [self.embed(t) for t in texts]


class QueryTokenizer:
    """
    查询分词与特征提取
    用于缓存键生成和语义分析的预处理
    """

    def __init__(self, max_length: int = 512):
        self.max_length = max_length

    def tokenize(self, text: str) -> dict:
        """对查询做基础分词,返回统计特征"""
        import re

        text = text[:self.max_length]
        words = re.findall(r'\w+', text.lower())

        return {
            "word_count": len(words),
            "char_count": len(text),
            "has_question": "?" in text or "?" in text,
            "first_word": words[0] if words else "",
            "keyword_set": set(words),
        }

    def jaccard_similarity(self, query_a: str, query_b: str) -> float:
        """
        Jaccard 相似度——作为语义缓存的快速预过滤
        如果关键词完全不重叠,大概率语义也不同
        """
        feat_a = self.tokenize(query_a)
        feat_b = self.tokenize(query_b)
        set_a = feat_a["keyword_set"]
        set_b = feat_b["keyword_set"]

        intersection = len(set_a & set_b)
        union = len(set_a | set_b)
        return intersection / union if union > 0 else 0.0

嵌入服务支持本地和远程两种模式。本地模式使用 paraphrase-multilingual-MiniLM-L12-v2——一个 384 维的轻量多语言模型,66MB 内存占用,CPU 上单次编码只需 5-10ms,非常适合实时缓存场景。

QueryTokenizer 中的 Jaccard 相似度可作为语义缓存的快速预过滤——如果两个查询的关键词完全不重叠(Jaccard = 0),大概率语义也不同,无需浪费嵌入向量的相似度计算。

八、写穿透与缓存预热

生产环境中,缓存数据和真实数据可能不一致。写穿透(Write-Through)策略保证每次写入都同步更新缓存:

class WriteThroughCache(AICacheSystem):
    """
    写穿透缓存装饰器
    每次写入数据时同步更新缓存
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._write_hooks: list[Callable] = []

    def on_write(self, hook: Callable[[str, Any], None]):
        """注册写入回调(用于持久化同步)"""
        self._write_hooks.append(hook)

    def write_through(
        self,
        key: str,
        value: Any,
        context: Optional[dict] = None,
        ttl: Optional[int] = None,
        sync_persistence: bool = True,
    ):
        """写穿透:写入缓存 + 触发持久化回调"""
        self.set(key, value, context, ttl=ttl)

        if sync_persistence:
            for hook in self._write_hooks:
                try:
                    hook(key, value)
                except Exception as e:
                    print(f"[Cache] Write hook failed: {e}")

    def warm_up(self, data: list[tuple[str, Any, Optional[dict]]]):
        """
        缓存预热:批量加载热点数据
        一般在系统启动时调用
        """
        for query, value, context in data:
            self.set(query, value, context)
        print(
            f"[Cache] Warm-up complete: {len(data)} entries loaded"
        )


class CachePersistence:
    """缓存持久化:磁盘备份与恢复"""

    def __init__(self, filepath: str = "/tmp/ai_cache_backup.json"):
        self.filepath = filepath

    def save(self, data: dict):
        with open(self.filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, default=str)

    def load(self) -> dict:
        try:
            with open(self.filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}

    def periodic_save(self, cache: AICacheSystem, interval: int = 300):
        """定时持久化缓存快照"""
        import threading

        def _save_loop():
            while True:
                time.sleep(interval)
                snapshot = cache.summary()
                self.save({"timestamp": time.time(), "stats": snapshot})

        thread = threading.Thread(target=_save_loop, daemon=True)
        thread.start()

写穿透确保缓存与后端数据源的一致性。on_write 注册的回调可以对接数据库写入、日志记录或消息队列推送。

预热接口 warm_up 在系统启动时调用,加载前一天的热点数据,避免刚启动时的缓存空窗期。

九、完整使用示例

将以上所有组件组合起来,演示一个真实的 AI 问答缓存场景:

def demo_ai_cache():
    """
    AI 缓存系统完整演示
    使用一个简单的模拟嵌入函数(真实环境应替换为实际嵌入服务)
    """
    import random

    # 模拟嵌入函数(演示用)
    def mock_embed(text: str) -> list[float]:
        random.seed(hash(text) % (2**31))
        return [random.random() for _ in range(4)]

    # 构建缓存系统
    cache = AICacheSystem(
        embed_func=mock_embed,
        exact_capacity=1000,
        semantic_threshold=0.85,  # 演示用较低阈值
        default_ttl=3600,
    )

    # 模拟 LLM 调用
    def expensive_llm_call(query: str) -> str:
        print(f"  [LLM] 正在计算: '{query}'")
        time.sleep(0.5)  # 模拟推理延迟
        return f"这是关于「{query}」的回答。"

    # 模拟查询序列
    queries = [
        ("公司的年假政策是什么", False),
        ("Please explain annual leave policy", False),
        ("年假几天", False),
        ("年假政策", False),       # 应命中语义缓存
        ("公司的年假政策是什么", False),  # 应命中精确缓存
        ("怎么请假", True),         # 不相关,不应命中
    ]

    print("=== AI 缓存系统演示 ===\n")

    for i, (query, should_miss) in enumerate(queries):
        print(f"查询 {i+1}: '{query}'")
        start = time.time()

        result = cache.get(query)
        if result is not None:
            elapsed = time.time() - start
            print(f"  [缓存] ✅ 命中! ({elapsed:.3f}s)")
        else:
            print(f"  [缓存] ❌ 未命中,执行 LLM 调用...")
            result = expensive_llm_call(query)
            cache.set(query, result)
            elapsed = time.time() - start
            print(f"  [完成] ({elapsed:.3f}s)")

        print(f"  结果: {result}\n")

    # 打印统计
    print("=== 缓存统计 ===")
    stats = cache.summary()
    for k, v in stats.items():
        print(f"  {k}: {v}")


def demo_semantic_matching():
    """
    语义缓存匹配演示
    展示不同表述如何命中同一个缓存结果
    """
    def mock_embed(text: str) -> list[float]:
        # 为演示设计:关键词越相似,向量越接近
        keywords = ["密码", "重置", "忘记", "登录", "账户", "修改"]
        vec = [0.0] * len(keywords)
        for i, kw in enumerate(keywords):
            if kw in text:
                vec[i] = 1.0
        # 归一化
        norm = sum(v * v for v in vec) ** 0.5
        return [v / norm for v in vec] if norm > 0 else vec

    cache = AICacheSystem(
        embed_func=mock_embed,
        semantic_threshold=0.75,
    )

    # 将原始问题缓存
    original = "如何重置密码?"
    answer = "步骤:1. 点击登录页"忘记密码" 2. 输入注册邮箱 3. 查收重置链接 4. 设置新密码"
    cache.set(original, answer)

    # 测试语义相近的查询
    test_queries = [
        "忘记密码怎么办",   # 应命中
        "密码重置步骤",     # 应命中
        "怎么修改登录密码", # 应命中
        "服务器配置参数",   # 不应命中
    ]

    print("=== 语义匹配演示 ===\n")
    print(f"原始: '{original}'")
    print(f"答案: {answer}\n")

    for q in test_queries:
        result = cache.get(q)
        hit = "✅ 命中" if result else "❌ 未命中"
        print(f"查询: '{q}' → {hit}")


if __name__ == "__main__":
    demo_ai_cache()
    print()
    demo_semantic_matching()

运行演示代码,你会看到:
- 第一次查询「公司的年假政策是什么」:未命中 → 执行 LLM 调用(0.5s 延迟)
- 第二次查询「Please explain annual leave policy」:语义命中 → 立即返回缓存结果(毫秒级)
- 第三次查询「公司的年假政策是什么」:精确命中 → 毫秒级响应
- 第四次查询「怎么请假」:不相关 → 未命中(正确行为)

十、缓存性能指标与监控

好的缓存系统需要可观测。以下是关键的监控指标和实现:

class CacheMonitor:
    """
    缓存监控器:记录、聚合、报告缓存性能指标
    """

    def __init__(self, cache: AICacheSystem, window_size: int = 3600):
        self.cache = cache
        self.window_size = window_size
        self.timeline: list[dict] = []

    def snapshot(self) -> dict:
        """生成当前缓存快照"""
        stats = self.cache.summary()
        now = time.time()

        snapshot = {
            "timestamp": now,
            "stats": stats,
            "latency_p95": self._estimate_p95_latency(),
            "memory_usage_mb": self.cache.exact_cache.lru.current_memory
            / (1024 * 1024),
            "eviction_rate": self._eviction_rate(),
        }
        self.timeline.append(snapshot)

        # 只保留时间窗口内的快照
        cutoff = now - self.window_size
        self.timeline = [s for s in self.timeline if s["timestamp"] > cutoff]

        return snapshot

    def _estimate_p95_latency(self) -> float:
        """估算 P95 延迟(单位:毫秒)"""
        recent = self.timeline[-100:] if len(self.timeline) > 100 else self.timeline
        if not recent:
            return 0.0
        latencies = sorted(
            s.get("latency_p95", 0) for s in recent
        )
        idx = int(len(latencies) * 0.95)
        return latencies[idx] if idx < len(latencies) else latencies[-1]

    def _eviction_rate(self) -> float:
        """计算最近窗口内的淘汰率"""
        if len(self.timeline) < 2:
            return 0.0
        recent = self.timeline[-2:]
        evictions_diff = (
            recent[1]["stats"]["exact_hits"]
            + recent[1]["stats"]["semantic_hits"]
            + recent[1]["stats"]["misses"]
        ) - (
            recent[0]["stats"]["exact_hits"]
            + recent[0]["stats"]["semantic_hits"]
            + recent[0]["stats"]["misses"]
        )
        return max(0, evictions_diff) / (len(self.cache.exact_cache.lru) + 1)

    def report(self) -> str:
        """生成人类可读的性能报告"""
        s = self.snapshot()
        stats = s["stats"]

        report = f"""
=== AI 缓存性能报告 ===
时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(s['timestamp']))}

命中统计:
  精确命中: {stats['exact_hits']}
  语义命中: {stats['semantic_hits']}
  总未命中: {stats['misses']}
  综合命中率: {stats['hit_rate']*100:.1f}%

缓存规模:
  精确缓存条目: {stats['exact_entries']}
  语义缓存条目: {stats['semantic_entries']}
  内存占用: {s['memory_usage_mb']:.1f} MB
  预估值 P95 延迟: {s['latency_p95']:.1f} ms
"""
        return report.strip()

关键的缓存性能指标解读

  • 命中率:综合命中率 60% 是底线。低于 60% 说明缓存策略需要调整——可能相似度阈值过高(语义缓存几乎不工作),或者 TTL 太短导致频繁过期。
  • 语义命中率:在综合命中率中,语义命中占比越高,说明系统对「不同表述同一问题」的缓存效果越好。如果精确命中率很高而语义命中率很低,说明用户查询高度标准化,语义缓存收益有限。
  • 淘汰率:高频淘汰意味着缓存容量不足。如果在内存还有余量的情况下淘汰率高,说明 LRU 的容量上限设得太低。

十一、常见问题与优化策略

11.1 缓存雪崩

大量缓存同时过期,导致所有请求穿透到后端。

解决方案:过期时间随机化

def _randomized_ttl(base_ttl: int, jitter: float = 0.1) -> int:
    """在基础 TTL 上增加随机抖动,防止缓存雪崩"""
    import random
    jitter_range = int(base_ttl * jitter)
    return base_ttl + random.randint(-jitter_range, jitter_range)

11.2 缓存击穿

热点键过期瞬间,大量并发请求同时查询同一个未命中的键。

解决方案:互斥锁 + 双重检查

class HotKeyProtector:
    """热点键保护:防止缓存击穿"""

    def __init__(self):
        self._locks: dict[str, threading.Lock] = {}
        self._global_lock = threading.Lock()

    def _get_lock(self, key: str) -> threading.Lock:
        with self._global_lock:
            if key not in self._locks:
                self._locks[key] = threading.Lock()
            return self._locks[key]

    def protect(self, key: str, compute_fn: Callable[[], Any]) -> Any:
        """
        互斥锁保护:
        同一时刻只有一个线程在计算,其余等待
        """
        lock = self._get_lock(key)
        with lock:
            return compute_fn()

11.3 缓存预热策略

系统重启后,缓存为空。可以使用滚动预热策略:

class RollingWarmUp:
    """
    滚动预热:从慢到快逐步加载缓存数据
    避免一次性加载大量数据导致系统抖动
    """

    def __init__(self, cache: AICacheSystem, loader: Callable[[int, int], list]):
        self.cache = cache
        self.loader = loader

    def warm_up(self, total: int, batch_size: int = 50, delay: float = 0.1):
        """分批加载缓存数据"""
        loaded = 0
        while loaded < total:
            batch = self.loader(loaded, min(batch_size, total - loaded))
            for query, value, context in batch:
                self.cache.set(query, value, context)
            loaded += len(batch)
            print(
                f"[WarmUp] {loaded}/{total} ({loaded*100//total}%)"
            )
            if loaded < total:
                time.sleep(delay)  # 防止系统过载

11.4 缓存容量规划

组件 内存估算公式 100万条目示例
LRU 链表 key + value + 指针 ≈ 200 bytes 200 MB
语义缓存 384维 float32 = 1536 bytes + 字符串 1.5 GB
索引结构 近似 10% 额外开销 150 MB

实践中,精确缓存建议分配 256-512MB,语义缓存根据嵌入向量维度计算。384 维向量每条 1.5KB,5000 条约 7.5MB,内存开销可控。

十二、总结

本文从零实现了 AI 缓存系统的完整架构:

  1. 精确缓存层:基于 LRU + 双向链表的 O(1) 查找,结合 TTL 过期和内存上限控制
  2. 语义缓存层:嵌入向量 + 余弦相似度,识别语义相似但表述不同的查询
  3. 缓存策略引擎:两级缓存穿透保护,写穿透保证数据一致性
  4. 运维工具:缓存监控、性能指标、预热策略、雪崩和击穿防护

选择缓存策略时的指导原则:

  • 查询固定且重复率高:只需精确缓存,语义缓存收益有限
  • 自然语言查询、用户表述多样:必须启用语义缓存,相似度阈值从 0.92 开始调优
  • 热点数据频繁更新:缩短 TTL(300-600秒)+ 写穿透模式
  • 冷启动场景:实现缓存预热 + 滚动加载,避免缓存空窗期

AI 缓存是一个「一分设计、九分调优」的工程。先根据业务场景选择合适的缓存策略组合,然后通过监控指标持续调整参数。理解和掌握缓存系统的底层原理,远比简单接入 Redis 等中间件更有价值——因为你可以在任何时候根据自己的场景需求,做出最优的设计决策。


📚 延伸阅读

如果你对 DeepSeek 的实战 用法感兴趣,推荐阅读我的另一篇文章:

👉 DeepSeek 实战指南:提示词工程、API 集成与效率提升全攻略

这篇文章系统地拆解了 DeepSeek 的提示词工程技巧、API 封装方法以及日常效率提升场景,全文代码可直接运行,适合已经上手 DeepSeek 但希望更高效使用的开发者。


本文是"手写 AI 系统"系列文章之一。该系列从零实现 AI 系统中的关键组件,涵盖 RAG、Agent、Function Calling、MCP 等核心技术,帮助你深入理解底层原理,构建属于自己的 AI 工具。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐