一、为什么需要推理缓存

1.1 重复请求的浪费

现实中的推理服务有一个被忽视的事实:大量请求是重复的。

以智慧安防为例:

  • 同一个摄像头,24 小时监控同一个入口
  • 每秒 25 帧,大部分帧的画面几乎一样
  • 对同一张人脸的识别结果,每秒被计算 25 次
  • 实际上,每秒只需要计算 1-2 次,其余 23 次是浪费

以电商搜索为例:

  • 用户搜索"手机壳",系统返回推荐结果
  • 10 秒后,另一个用户也搜索"手机壳"
  • 两次搜索的输入完全相同,但第二次重新计算了整个模型

量化分析:

假设推理延迟 10ms,QPS 1000,重复请求比例 30%。

  • 不用缓存:每秒计算量 1000 次,总计算时间 1000 × 10ms = 10 秒(需要多个 NPU 并行)
  • 用缓存:每秒实际计算 700 次(300 次命中缓存),缓存命中延迟 < 0.1ms(内存查找),节省 30% 的 NPU 算力

1.2 缓存的代价

缓存不是免费的,需要权衡:

收益:

  • 减少 NPU 计算量
  • 降低延迟(缓存命中 < 0.1ms vs 推理 10ms)
  • 提高吞吐量
  • 节省算力成本

代价:

  • 内存占用(缓存需要存储结果)
  • 一致性风险(模型更新后缓存可能过期)
  • 实现复杂度(缓存键生成、失效策略、并发控制)
  • 首次请求仍然慢(缓存未命中)

什么时候适合用缓存:

  • 输入空间有限(如固定类别分类)
  • 重复请求比例高
  • 对延迟敏感
  • 模型更新不频繁

什么时候不适合:

  • 输入几乎不重复(如随机生成的内容)
  • 内存非常紧张
  • 模型频繁更新(缓存频繁失效)

二、缓存策略

2.1 LRU 缓存

核心思想: 当缓存满了,淘汰最久没有被访问的条目。

为什么选 LRU?

  • 实现简单,性能好
  • 符合"时间局部性"原理:最近被访问的数据,很可能很快再次被访问
  • 适用于大多数推理场景

数据结构: 使用 OrderedDict 维护访问顺序。每次访问时,将条目移到末尾(最新)。淘汰时,删除头部(最旧)的条目。

时间复杂度: 查找 O(1)、插入 O(1)、淘汰 O(1)。

为什么不用 list? list 的删除是 O(n),当缓存很大时会变慢。OrderedDict 的删除是 O(1)。

import hashlib
import time
import threading
from collections import OrderedDict


class LRUCache:
    def __init__(self, max_size=1000, ttl_seconds=300):
        """
        参数:
            max_size: 缓存最大条目数
                太小: 命中率低
                太大: 内存占用高
                经验值: 预估重复请求数的 2-5 倍

            ttl_seconds: 缓存条目过期时间(秒)
                为什么要过期?
                1. 防止内存无限增长
                2. 模型更新后旧结果自动失效
                3. 控制缓存的新鲜度

                TTL 的选择:
                - 模型不更新: 可以设很长(如 1 小时)
                - 模型频繁更新: 设很短(如 10 秒)
                - 一般场景: 5-10 分钟
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache = OrderedDict()  # {key: (value, timestamp)}
        self.lock = threading.RLock()  # 线程安全锁
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """获取缓存"""
        with self.lock:
            if key in self.cache:
                value, timestamp = self.cache[key]
                # 检查是否过期
                if time.time() - timestamp < self.ttl_seconds:
                    # 命中: 移到末尾(标记为最近使用)
                    self.cache.move_to_end(key)
                    self.hits += 1
                    return value
                else:
                    # 过期: 删除
                    del self.cache[key]
            self.misses += 1
            return None

    def put(self, key, value):
        """放入缓存"""
        with self.lock:
            if key in self.cache:
                # 更新已有条目
                self.cache.move_to_end(key)
                self.cache[key] = (value, time.time())
            else:
                # 检查是否需要淘汰
                if len(self.cache) >= self.max_size:
                    # 删除最久未使用的(头部)
                    self.cache.popitem(last=False)
                # 插入新条目
                self.cache[key] = (value, time.time())

    def clear(self):
        """清空缓存"""
        with self.lock:
            self.cache.clear()
            self.hits = 0
            self.misses = 0

    def stats(self):
        """获取缓存统计"""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
        }

使用示例:

cache = LRUCache(max_size=1000, ttl_seconds=300)


def infer_with_cache(model, input_tensor):
    """带缓存的推理"""
    # 生成缓存键(基于输入内容的哈希)
    cache_key = hashlib.md5(input_tensor.numpy().tobytes()).hexdigest()

    # 尝试从缓存获取
    cached_result = cache.get(cache_key)
    if cached_result is not None:
        return cached_result, "cache_hit"

    # 缓存未命中,执行推理
    with torch.no_grad():
        result = model(input_tensor.npu()).cpu()

    # 存入缓存
    cache.put(cache_key, result)

    return result, "cache_miss"


# 模拟 1000 个请求(30% 重复)
import random

inputs = [torch.randn(1, 3, 224, 224) for _ in range(700)]
# 加入 300 个重复请求
for _ in range(300):
    inputs.append(random.choice(inputs))
random.shuffle(inputs)

for inp in inputs:
    result, status = infer_with_cache(model, inp)

print(f"缓存统计: {cache.stats()}")
# 期望命中率: ~30%

2.2 LFU 缓存

核心思想: 当缓存满了,淘汰访问频率最低的条目。

与 LRU 的区别:

  • LRU:淘汰"最久没访问的"
  • LFU:淘汰"访问次数最少的"

适用场景:

  • 某些输入被反复请求(如热门商品的识别)
  • 希望保留高频条目,淘汰低频条目

实现方式: 使用两个数据结构——freq_map(频率 → 该频率的所有 key)和 key_map(key → 值和频率)。

class LFUCache:
    def __init__(self, max_size=1000, ttl_seconds=300):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.key_map = {}      # {key: (value, freq, timestamp)}
        self.freq_map = {}     # {freq: OrderedDict({key: None})}
        self.min_freq = 0
        self.lock = threading.RLock()
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """获取缓存并增加频率"""
        with self.lock:
            if key in self.key_map:
                value, freq, timestamp = self.key_map[key]
                if time.time() - timestamp < self.ttl_seconds:
                    self._increase_freq(key, freq)
                    self.hits += 1
                    return value
                else:
                    self._remove(key)
            self.misses += 1
            return None

    def put(self, key, value):
        """放入缓存"""
        with self.lock:
            if key in self.key_map:
                _, freq, _ = self.key_map[key]
                self.key_map[key] = (value, freq + 1, time.time())
                self._increase_freq(key, freq)
            else:
                if len(self.key_map) >= self.max_size:
                    self._evict()
                self.key_map[key] = (value, 1, time.time())
                self.min_freq = 1
                if 1 not in self.freq_map:
                    self.freq_map[1] = OrderedDict()
                self.freq_map[1][key] = None

    def _increase_freq(self, key, old_freq):
        """增加 key 的频率"""
        del self.freq_map[old_freq][key]
        if not self.freq_map[old_freq]:
            del self.freq_map[old_freq]
            if self.min_freq == old_freq:
                self.min_freq = old_freq + 1
        new_freq = old_freq + 1
        if new_freq not in self.freq_map:
            self.freq_map[new_freq] = OrderedDict()
        self.freq_map[new_freq][key] = None
        value, _, timestamp = self.key_map[key]
        self.key_map[key] = (value, new_freq, timestamp)

    def _evict(self):
        """淘汰频率最低的条目"""
        if self.min_freq in self.freq_map and self.freq_map[self.min_freq]:
            key, _ = self.freq_map[self.min_freq].popitem(last=False)
            if not self.freq_map[self.min_freq]:
                del self.freq_map[self.min_freq]
            del self.key_map[key]

    def _remove(self, key):
        """删除指定 key"""
        if key in self.key_map:
            _, freq, _ = self.key_map[key]
            del self.key_map[key]
            if freq in self.freq_map:
                del self.freq_map[freq][key]
                if not self.freq_map[freq]:
                    del self.freq_map[freq]

    def stats(self):
        """获取统计"""
        total = self.hits + self.misses
        return {
            'size': len(self.key_map),
            'max_size': self.max_size,
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': self.hits / total if total > 0 else 0,
        }

三、缓存键生成

3.1 内容哈希

缓存键的质量直接决定缓存的命中率。

好的缓存键:

  • 相同输入 → 相同键(确定性)
  • 不同输入 → 不同键(唯一性)
  • 生成速度快(不能成为瓶颈)

坏的缓存键:

  • 用时间戳作为键 → 永远不会命中
  • 用随机数作为键 → 永远不会命中
  • 用文件路径作为键 → 文件内容变了但键没变
class CacheKeyGenerator:
    @staticmethod
    def tensor_hash(tensor):
        """张量内容哈希"""
        return hashlib.sha256(
            tensor.detach().cpu().contiguous().numpy().tobytes()
        ).hexdigest()

    @staticmethod
    def batch_hash(tensors):
        """批量张量哈希"""
        combined = b""
        for t in tensors:
            combined += t.detach().cpu().contiguous().numpy().tobytes()
        return hashlib.sha256(combined).hexdigest()

    @staticmethod
    def text_hash(text):
        """文本哈希"""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    @staticmethod
    def params_hash(**kwargs):
        """参数组合哈希"""
        sorted_params = sorted(kwargs.items())
        param_str = str(sorted_params)
        return hashlib.sha256(param_str.encode()).hexdigest()

使用示例:

key_gen = CacheKeyGenerator()

# 图像推理缓存键
image = torch.randn(1, 3, 224, 224)
key = key_gen.tensor_hash(image)
print(f"图像缓存键: {key[:16]}...")

# 文本推理缓存键
text = "今天天气真好"
key = key_gen.text_hash(text)
print(f"文本缓存键: {key[:16]}...")

# 带参数的缓存键
key = key_gen.params_hash(
    model="resnet50",
    image_hash=key_gen.tensor_hash(image),
    threshold=0.5,
)
print(f"带参数缓存键: {key[:16]}...")

四、完整推理缓存系统

架构流程:

推理请求 → 缓存键生成 → LRU 缓存
                        ├─ 命中 → 直接返回结果
                        └─ 未命中 → NPU 推理 → 结果存入缓存 → 返回结果
class InferenceCacheSystem:
    def __init__(self, model, cache_size=1000, ttl_seconds=300):
        self.model = model
        self.cache = LRUCache(max_size=cache_size, ttl_seconds=ttl_seconds)
        self.key_gen = CacheKeyGenerator()
        self.stats_log = []

    def infer(self, input_tensor, model_params=None):
        """带缓存的推理"""
        # 1. 生成缓存键
        if model_params:
            cache_key = self.key_gen.params_hash(
                input_hash=self.key_gen.tensor_hash(input_tensor),
                model_hash=self.key_gen.text_hash(str(model_params)),
            )
        else:
            cache_key = self.key_gen.tensor_hash(input_tensor)

        # 2. 查缓存
        cached = self.cache.get(cache_key)
        if cached is not None:
            return cached, "hit"

        # 3. 执行推理
        start_time = time.time()
        with torch.no_grad():
            result = self.model(input_tensor.npu()).cpu()
        infer_time = (time.time() - start_time) * 1000

        # 4. 存入缓存
        self.cache.put(cache_key, result)

        # 5. 记录统计
        self.stats_log.append({
            'cache_key': cache_key[:16],
            'status': 'miss',
            'infer_time_ms': infer_time,
            'timestamp': time.time(),
        })

        return result, "miss"

    def batch_infer(self, input_tensors):
        """批量推理(自动缓存)"""
        results = []
        hit_count = 0
        miss_count = 0

        for inp in input_tensors:
            result, status = self.infer(inp)
            results.append(result)
            if status == "hit":
                hit_count += 1
            else:
                miss_count += 1

        return results, {
            'total': len(input_tensors),
            'hits': hit_count,
            'misses': miss_count,
            'hit_rate': hit_count / len(input_tensors),
        }

    def invalidate(self, pattern=None):
        """使缓存失效"""
        if pattern:
            keys_to_remove = [k for k in self.cache.cache.keys() if pattern in k]
            for key in keys_to_remove:
                del self.cache.cache[key]
        else:
            self.cache.clear()

    def print_stats(self):
        """打印统计"""
        stats = self.cache.stats()
        print(f"
缓存统计:")
        print(f"  大小: {stats['size']}/{stats['max_size']}")
        print(f"  命中: {stats['hits']}")
        print(f"  未命中: {stats['misses']}")
        print(f"  命中率: {stats['hit_rate']:.1%}")

五、常见问题

问题 原因 解决方案
命中率低 缓存键太精确或输入几乎不重复 放宽缓存键(如量化后再哈希)
内存占用高 缓存太大或结果太大 减小缓存大小、压缩缓存结果
缓存不一致 模型更新后旧缓存未失效 设置 TTL、版本化缓存键
并发性能差 锁竞争激烈 分片缓存、无锁数据结构
缓存雪崩 大量缓存同时过期 添加随机 TTL 偏移

相关仓库

  • CANN - 昇腾异构计算架构 https://atomgit.com/cann
  • cann-recipes-infer - 推理配方 https://atomgit.com/cann/ops-nn
  • ops-nn - 神经网络算子库 https://atomgit.com/cann/ops-nn
  • driver - 驱动 https://atomgit.com/cann/driver
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐