CANN 推理缓存:相同输入的秒级响应实战
·
一、为什么需要推理缓存
1.1 重复请求的浪费
现实中的推理服务有一个被忽视的事实:大量请求是重复的。
以智慧安防为例:
- 同一个摄像头,24 小时监控同一个入口
- 每秒 25 帧,大部分帧的画面几乎一样
- 对同一张人脸的识别结果,每秒被计算 25 次
- 实际上,每秒只需要计算 1-2 次,其余 23 次是浪费
以电商搜索为例:
- 用户搜索"手机壳",系统返回推荐结果
- 10 秒后,另一个用户也搜索"手机壳"
- 两次搜索的输入完全相同,但第二次重新计算了整个模型
量化分析:
假设推理延迟 10ms,QPS 1000,重复请求比例 30%。
- 不用缓存:每秒计算量 1000 次,总计算时间 1000 × 10ms = 10 秒(需要多个 NPU 并行)
- 用缓存:每秒实际计算 700 次(300 次命中缓存),缓存命中延迟 < 0.1ms(内存查找),节省 30% 的 NPU 算力
1.2 缓存的代价
缓存不是免费的,需要权衡:
收益:
- 减少 NPU 计算量
- 降低延迟(缓存命中 < 0.1ms vs 推理 10ms)
- 提高吞吐量
- 节省算力成本
代价:
- 内存占用(缓存需要存储结果)
- 一致性风险(模型更新后缓存可能过期)
- 实现复杂度(缓存键生成、失效策略、并发控制)
- 首次请求仍然慢(缓存未命中)
什么时候适合用缓存:
- 输入空间有限(如固定类别分类)
- 重复请求比例高
- 对延迟敏感
- 模型更新不频繁
什么时候不适合:
- 输入几乎不重复(如随机生成的内容)
- 内存非常紧张
- 模型频繁更新(缓存频繁失效)
二、缓存策略
2.1 LRU 缓存
核心思想: 当缓存满了,淘汰最久没有被访问的条目。
为什么选 LRU?
- 实现简单,性能好
- 符合"时间局部性"原理:最近被访问的数据,很可能很快再次被访问
- 适用于大多数推理场景
数据结构: 使用 OrderedDict 维护访问顺序。每次访问时,将条目移到末尾(最新)。淘汰时,删除头部(最旧)的条目。
时间复杂度: 查找 O(1)、插入 O(1)、淘汰 O(1)。
为什么不用 list? list 的删除是 O(n),当缓存很大时会变慢。OrderedDict 的删除是 O(1)。
import hashlib
import time
import threading
from collections import OrderedDict
class LRUCache:
def __init__(self, max_size=1000, ttl_seconds=300):
"""
参数:
max_size: 缓存最大条目数
太小: 命中率低
太大: 内存占用高
经验值: 预估重复请求数的 2-5 倍
ttl_seconds: 缓存条目过期时间(秒)
为什么要过期?
1. 防止内存无限增长
2. 模型更新后旧结果自动失效
3. 控制缓存的新鲜度
TTL 的选择:
- 模型不更新: 可以设很长(如 1 小时)
- 模型频繁更新: 设很短(如 10 秒)
- 一般场景: 5-10 分钟
"""
self.max_size = max_size
self.ttl_seconds = ttl_seconds
self.cache = OrderedDict() # {key: (value, timestamp)}
self.lock = threading.RLock() # 线程安全锁
self.hits = 0
self.misses = 0
def get(self, key):
"""获取缓存"""
with self.lock:
if key in self.cache:
value, timestamp = self.cache[key]
# 检查是否过期
if time.time() - timestamp < self.ttl_seconds:
# 命中: 移到末尾(标记为最近使用)
self.cache.move_to_end(key)
self.hits += 1
return value
else:
# 过期: 删除
del self.cache[key]
self.misses += 1
return None
def put(self, key, value):
"""放入缓存"""
with self.lock:
if key in self.cache:
# 更新已有条目
self.cache.move_to_end(key)
self.cache[key] = (value, time.time())
else:
# 检查是否需要淘汰
if len(self.cache) >= self.max_size:
# 删除最久未使用的(头部)
self.cache.popitem(last=False)
# 插入新条目
self.cache[key] = (value, time.time())
def clear(self):
"""清空缓存"""
with self.lock:
self.cache.clear()
self.hits = 0
self.misses = 0
def stats(self):
"""获取缓存统计"""
total = self.hits + self.misses
hit_rate = self.hits / total if total > 0 else 0
return {
'size': len(self.cache),
'max_size': self.max_size,
'hits': self.hits,
'misses': self.misses,
'hit_rate': hit_rate,
}
使用示例:
cache = LRUCache(max_size=1000, ttl_seconds=300)
def infer_with_cache(model, input_tensor):
"""带缓存的推理"""
# 生成缓存键(基于输入内容的哈希)
cache_key = hashlib.md5(input_tensor.numpy().tobytes()).hexdigest()
# 尝试从缓存获取
cached_result = cache.get(cache_key)
if cached_result is not None:
return cached_result, "cache_hit"
# 缓存未命中,执行推理
with torch.no_grad():
result = model(input_tensor.npu()).cpu()
# 存入缓存
cache.put(cache_key, result)
return result, "cache_miss"
# 模拟 1000 个请求(30% 重复)
import random
inputs = [torch.randn(1, 3, 224, 224) for _ in range(700)]
# 加入 300 个重复请求
for _ in range(300):
inputs.append(random.choice(inputs))
random.shuffle(inputs)
for inp in inputs:
result, status = infer_with_cache(model, inp)
print(f"缓存统计: {cache.stats()}")
# 期望命中率: ~30%
2.2 LFU 缓存
核心思想: 当缓存满了,淘汰访问频率最低的条目。
与 LRU 的区别:
- LRU:淘汰"最久没访问的"
- LFU:淘汰"访问次数最少的"
适用场景:
- 某些输入被反复请求(如热门商品的识别)
- 希望保留高频条目,淘汰低频条目
实现方式: 使用两个数据结构——freq_map(频率 → 该频率的所有 key)和 key_map(key → 值和频率)。
class LFUCache:
def __init__(self, max_size=1000, ttl_seconds=300):
self.max_size = max_size
self.ttl_seconds = ttl_seconds
self.key_map = {} # {key: (value, freq, timestamp)}
self.freq_map = {} # {freq: OrderedDict({key: None})}
self.min_freq = 0
self.lock = threading.RLock()
self.hits = 0
self.misses = 0
def get(self, key):
"""获取缓存并增加频率"""
with self.lock:
if key in self.key_map:
value, freq, timestamp = self.key_map[key]
if time.time() - timestamp < self.ttl_seconds:
self._increase_freq(key, freq)
self.hits += 1
return value
else:
self._remove(key)
self.misses += 1
return None
def put(self, key, value):
"""放入缓存"""
with self.lock:
if key in self.key_map:
_, freq, _ = self.key_map[key]
self.key_map[key] = (value, freq + 1, time.time())
self._increase_freq(key, freq)
else:
if len(self.key_map) >= self.max_size:
self._evict()
self.key_map[key] = (value, 1, time.time())
self.min_freq = 1
if 1 not in self.freq_map:
self.freq_map[1] = OrderedDict()
self.freq_map[1][key] = None
def _increase_freq(self, key, old_freq):
"""增加 key 的频率"""
del self.freq_map[old_freq][key]
if not self.freq_map[old_freq]:
del self.freq_map[old_freq]
if self.min_freq == old_freq:
self.min_freq = old_freq + 1
new_freq = old_freq + 1
if new_freq not in self.freq_map:
self.freq_map[new_freq] = OrderedDict()
self.freq_map[new_freq][key] = None
value, _, timestamp = self.key_map[key]
self.key_map[key] = (value, new_freq, timestamp)
def _evict(self):
"""淘汰频率最低的条目"""
if self.min_freq in self.freq_map and self.freq_map[self.min_freq]:
key, _ = self.freq_map[self.min_freq].popitem(last=False)
if not self.freq_map[self.min_freq]:
del self.freq_map[self.min_freq]
del self.key_map[key]
def _remove(self, key):
"""删除指定 key"""
if key in self.key_map:
_, freq, _ = self.key_map[key]
del self.key_map[key]
if freq in self.freq_map:
del self.freq_map[freq][key]
if not self.freq_map[freq]:
del self.freq_map[freq]
def stats(self):
"""获取统计"""
total = self.hits + self.misses
return {
'size': len(self.key_map),
'max_size': self.max_size,
'hits': self.hits,
'misses': self.misses,
'hit_rate': self.hits / total if total > 0 else 0,
}
三、缓存键生成
3.1 内容哈希
缓存键的质量直接决定缓存的命中率。
好的缓存键:
- 相同输入 → 相同键(确定性)
- 不同输入 → 不同键(唯一性)
- 生成速度快(不能成为瓶颈)
坏的缓存键:
- 用时间戳作为键 → 永远不会命中
- 用随机数作为键 → 永远不会命中
- 用文件路径作为键 → 文件内容变了但键没变
class CacheKeyGenerator:
@staticmethod
def tensor_hash(tensor):
"""张量内容哈希"""
return hashlib.sha256(
tensor.detach().cpu().contiguous().numpy().tobytes()
).hexdigest()
@staticmethod
def batch_hash(tensors):
"""批量张量哈希"""
combined = b""
for t in tensors:
combined += t.detach().cpu().contiguous().numpy().tobytes()
return hashlib.sha256(combined).hexdigest()
@staticmethod
def text_hash(text):
"""文本哈希"""
return hashlib.sha256(text.encode('utf-8')).hexdigest()
@staticmethod
def params_hash(**kwargs):
"""参数组合哈希"""
sorted_params = sorted(kwargs.items())
param_str = str(sorted_params)
return hashlib.sha256(param_str.encode()).hexdigest()
使用示例:
key_gen = CacheKeyGenerator()
# 图像推理缓存键
image = torch.randn(1, 3, 224, 224)
key = key_gen.tensor_hash(image)
print(f"图像缓存键: {key[:16]}...")
# 文本推理缓存键
text = "今天天气真好"
key = key_gen.text_hash(text)
print(f"文本缓存键: {key[:16]}...")
# 带参数的缓存键
key = key_gen.params_hash(
model="resnet50",
image_hash=key_gen.tensor_hash(image),
threshold=0.5,
)
print(f"带参数缓存键: {key[:16]}...")
四、完整推理缓存系统
架构流程:
推理请求 → 缓存键生成 → LRU 缓存
├─ 命中 → 直接返回结果
└─ 未命中 → NPU 推理 → 结果存入缓存 → 返回结果
class InferenceCacheSystem:
def __init__(self, model, cache_size=1000, ttl_seconds=300):
self.model = model
self.cache = LRUCache(max_size=cache_size, ttl_seconds=ttl_seconds)
self.key_gen = CacheKeyGenerator()
self.stats_log = []
def infer(self, input_tensor, model_params=None):
"""带缓存的推理"""
# 1. 生成缓存键
if model_params:
cache_key = self.key_gen.params_hash(
input_hash=self.key_gen.tensor_hash(input_tensor),
model_hash=self.key_gen.text_hash(str(model_params)),
)
else:
cache_key = self.key_gen.tensor_hash(input_tensor)
# 2. 查缓存
cached = self.cache.get(cache_key)
if cached is not None:
return cached, "hit"
# 3. 执行推理
start_time = time.time()
with torch.no_grad():
result = self.model(input_tensor.npu()).cpu()
infer_time = (time.time() - start_time) * 1000
# 4. 存入缓存
self.cache.put(cache_key, result)
# 5. 记录统计
self.stats_log.append({
'cache_key': cache_key[:16],
'status': 'miss',
'infer_time_ms': infer_time,
'timestamp': time.time(),
})
return result, "miss"
def batch_infer(self, input_tensors):
"""批量推理(自动缓存)"""
results = []
hit_count = 0
miss_count = 0
for inp in input_tensors:
result, status = self.infer(inp)
results.append(result)
if status == "hit":
hit_count += 1
else:
miss_count += 1
return results, {
'total': len(input_tensors),
'hits': hit_count,
'misses': miss_count,
'hit_rate': hit_count / len(input_tensors),
}
def invalidate(self, pattern=None):
"""使缓存失效"""
if pattern:
keys_to_remove = [k for k in self.cache.cache.keys() if pattern in k]
for key in keys_to_remove:
del self.cache.cache[key]
else:
self.cache.clear()
def print_stats(self):
"""打印统计"""
stats = self.cache.stats()
print(f"
缓存统计:")
print(f" 大小: {stats['size']}/{stats['max_size']}")
print(f" 命中: {stats['hits']}")
print(f" 未命中: {stats['misses']}")
print(f" 命中率: {stats['hit_rate']:.1%}")
五、常见问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 命中率低 | 缓存键太精确或输入几乎不重复 | 放宽缓存键(如量化后再哈希) |
| 内存占用高 | 缓存太大或结果太大 | 减小缓存大小、压缩缓存结果 |
| 缓存不一致 | 模型更新后旧缓存未失效 | 设置 TTL、版本化缓存键 |
| 并发性能差 | 锁竞争激烈 | 分片缓存、无锁数据结构 |
| 缓存雪崩 | 大量缓存同时过期 | 添加随机 TTL 偏移 |
相关仓库
- CANN - 昇腾异构计算架构 https://atomgit.com/cann
- cann-recipes-infer - 推理配方 https://atomgit.com/cann/ops-nn
- ops-nn - 神经网络算子库 https://atomgit.com/cann/ops-nn
- driver - 驱动 https://atomgit.com/cann/driver
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)