FlashAttention正确性基准:算法对拍与数值精度对照表
·
某团队在昇腾NPU上部署FlashAttention后,性能提升了8倍,但业务方报告:"模型输出变了一点,虽然不多,但不允许。"团队对比了NPU输出和GPU输出,发现结果有微小差异——最大相对误差约1e-3。他们不确定:这1e-3是正常的数值误差,还是bug?
问题出在没有建立正确性基准。FlashAttention在数学上与标准Attention等价(它是精确注意力算法,并非近似计算),但它引入了tiling、在线softmax、分块累加导致的浮点运算顺序变化等多个环节,每个环节都可能带来舍入误差。因此需要一个系统的基准测试框架,区分正常误差范围和异常bug。
今天把FlashAttention正确性基准的建立方法和数值精度对照表讲清楚。
正确性基准的设计
分层测试框架
import math
import os
from typing import Dict, List, Tuple

import numpy as np
import torch
class FlashAttentionCorrectnessBenchmark:
    """
    FlashAttention correctness benchmark.

    Goals:
    1. Verify algorithmic correctness (compare against standard attention).
    2. Build a numerical-precision reference table (error ranges per config).
    3. Tell normal floating-point error apart from bugs.
    """

    def __init__(self, rtol=1e-3, atol=1e-6):
        # Tolerances stored for callers; the methods of this class do not
        # read them.
        self.rtol = rtol
        self.atol = atol

    def reference_attention(self, q, k, v, mask=None):
        """
        Standard attention (reference implementation).

        Computes ``softmax(q @ k^T / sqrt(d_k) + mask) @ v`` with the full
        score matrix materialized; ``d_k`` is the last dimension of ``q``.
        ``mask`` is *additive*: it is added to the scores, so masked
        positions should hold a large negative value (or ``-inf``) and the
        tensor must broadcast against the score matrix.
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores + mask
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        return output

    def flash_attention(self, q, k, v, block_size=128, mask=None):
        """
        Tiled attention with an online (streaming) softmax.

        Simplified pure-PyTorch version of the FlashAttention algorithm (per
        the original note, the deployed path uses a CANN operator instead).
        Queries are processed in tiles of ``block_size`` rows. For each query
        tile the key/value tiles are streamed while three per-row quantities
        are maintained: the running score max ``m``, the running softmax
        normalizer ``l`` (relative to ``m``), and an *unnormalized* output
        accumulator. The accumulator is divided by ``l`` once at the end.

        Args:
            q: query tensor ``(..., S_q, D)``.
            k: key tensor ``(..., S_k, D)``.
            v: value tensor ``(..., S_k, D_v)``.
            block_size: tile size for both loops; ``S_q`` and ``S_k`` do not
                need to be multiples of it.
            mask: optional additive mask, same convention as
                ``reference_attention``; must broadcast to ``(..., S_q, S_k)``.

        Returns:
            Tensor ``(..., S_q, D_v)`` with the dtype of ``q``.

        Raises:
            ValueError: if ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        seq_q = q.shape[-2]
        seq_k = k.shape[-2]
        scale = 1.0 / math.sqrt(q.shape[-1])
        neg_inf = float("-inf")

        if mask is not None:
            # Expand (a view, no copy) to full (S_q, S_k) so per-tile slicing
            # also works for masks with size-1 query/key dims.
            mask = mask.expand(*mask.shape[:-2], seq_q, seq_k)

        output = q.new_zeros(q.shape[:-1] + (v.shape[-1],))

        for i in range(0, seq_q, block_size):
            q_block = q[..., i:i + block_size, :]
            # Per-row stats keep a trailing singleton dim so they broadcast
            # against (rows, keys) scores and (rows, D_v) outputs. The last
            # tile may have fewer than block_size rows.
            stat_shape = q_block.shape[:-1] + (1,)
            m_i = torch.full(stat_shape, neg_inf, dtype=q.dtype, device=q.device)
            l_i = torch.zeros(stat_shape, dtype=q.dtype, device=q.device)
            acc = q_block.new_zeros(q_block.shape[:-1] + (v.shape[-1],))

            for j in range(0, seq_k, block_size):
                k_block = k[..., j:j + block_size, :]
                v_block = v[..., j:j + block_size, :]
                s_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
                if mask is not None:
                    s_block = s_block + mask[..., i:i + block_size, j:j + block_size]

                m_new = torch.maximum(m_i, s_block.amax(dim=-1, keepdim=True))
                # Rows that have seen only -inf scores so far have
                # m_new == -inf; shift by 0 there so exp(-inf - (-inf)) = NaN
                # is never formed (their contributions are exactly 0).
                m_shift = torch.where(m_new == neg_inf, torch.zeros_like(m_new), m_new)
                alpha = torch.exp(m_i - m_shift)  # rescales the old stats
                p = torch.exp(s_block - m_shift)
                l_i = alpha * l_i + p.sum(dim=-1, keepdim=True)
                # Unnormalized accumulator: only rescale by alpha; dividing
                # by l happens once after the last key tile.
                acc = alpha * acc + torch.matmul(p, v_block)
                m_i = m_new

            output[..., i:i + block_size, :] = acc / l_i
        return output
class CorrectnessTestSuite:
    """
    Correctness test suite for a FlashAttention implementation.

    Covers basic correctness, error statistics, shape boundaries, large and
    small input magnitudes, masked attention, a long sequence and degenerate
    inputs. Every ``test_*`` method takes the implementation under test,
    ``fn``, called as ``fn(q, k, v)``; ``test_masked_attention`` additionally
    passes ``mask=`` (additive, same convention as ``reference_attention``).
    Each returns a ``(passed, details)`` tuple. Inputs are float32 tensors
    of shape (B, H, S, D) drawn from ``torch.randn``.
    """

    def __init__(self):
        # Not read by any method of this class; kept for compatibility.
        self.tests = []

    def run_all_tests(self, flash_attention_fn):
        """
        Run every test against ``flash_attention_fn`` and print a summary.

        Returns:
            dict mapping test name to ``{"passed": bool, "details": str}``,
            or ``{"passed": False, "error": str}`` when the test raised.
        """
        results = {}
        tests = [
            ("Basic", self.test_basic_correctness),
            ("NumericalPrecision", self.test_numerical_precision),
            ("BoundaryConditions", self.test_boundary_conditions),
            ("Overflow", self.test_overflow_handling),
            ("Underflow", self.test_underflow_handling),
            ("Masked", self.test_masked_attention),
            ("LongSequence", self.test_long_sequence),
            ("ExtremeValues", self.test_extreme_values),
        ]
        print("\n=== FlashAttention正确性测试套件 ===\n")
        for name, test_fn in tests:
            try:
                passed, details = test_fn(flash_attention_fn)
                results[name] = {"passed": passed, "details": details}
                status = "✅ PASS" if passed else "❌ FAIL"
                print(f" {name:<25} {status}")
                if details:
                    print(f" {details}")
            except Exception as e:
                # Broad on purpose: one crashing test must not abort the
                # whole suite; the error is recorded and reported instead.
                results[name] = {"passed": False, "error": str(e)}
                print(f" {name:<25} ❌ ERROR: {e}")
        total = len(results)
        passed = sum(1 for r in results.values() if r["passed"])
        print(f"\n结果: {passed}/{total} 通过")
        return results

    def test_basic_correctness(self, fn):
        """Basic correctness: max abs error < 1e-4 and max rel error < 1e-2."""
        torch.manual_seed(42)
        B, H, S, D = 2, 4, 128, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)
        abs_err = (ref_out - test_out).abs()
        max_diff = abs_err.max().item()
        # Attention outputs near 0 would turn any rounding noise into a huge
        # relative error, so the denominator is floored at 1e-3.
        rel_err = abs_err / ref_out.abs().clamp_min(1e-3)
        max_rtol = rel_err.max().item()
        passed = max_diff < 1e-4 and max_rtol < 1e-2
        return passed, f"max_diff={max_diff:.2e}, max_rtol={max_rtol:.2e}"

    def test_numerical_precision(self, fn):
        """Error-distribution statistics; passes when max abs error < 1e-2."""
        torch.manual_seed(123)
        B, H, S, D = 2, 8, 512, 128
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)
        # Error distribution
        diff = (ref_out - test_out).abs()
        metrics = {
            "max_abs_error": diff.max().item(),
            "mean_abs_error": diff.mean().item(),
            "median_abs_error": diff.median().item(),
            "p99_abs_error": torch.quantile(diff, 0.99).item(),
        }
        # Expected error range on Ascend NPU
        passed = metrics["max_abs_error"] < 1e-2
        details = f"max={metrics['max_abs_error']:.2e}, mean={metrics['mean_abs_error']:.2e}"
        return passed, details

    def test_boundary_conditions(self, fn):
        """Shape boundaries (tiny/odd S, D, H, B); each case needs max abs error <= 1e-3."""
        # Seeded so a failing case is reproducible.
        torch.manual_seed(42)
        tests = [
            ("S=1", (1, 1, 1, 64)),
            ("S=2", (1, 1, 2, 64)),
            ("S=block_size", (1, 1, 128, 64)),
            ("S=block_size+1", (1, 1, 129, 64)),
            ("S=2*block_size", (1, 1, 256, 64)),
            ("D=1", (1, 1, 64, 1)),
            ("D=256", (1, 1, 64, 256)),
            ("H=1", (1, 1, 64, 64)),
            ("H=128", (1, 128, 64, 64)),
            ("B=1", (1, 8, 128, 64)),
            ("B=32", (32, 8, 128, 64)),
        ]
        all_passed = True
        failed_cases = []
        for name, shape in tests:
            try:
                q = torch.randn(*shape)
                k = torch.randn(*shape)
                v = torch.randn(*shape)
                ref_out = self.reference_attention(q, k, v)
                test_out = fn(q, k, v)
                diff = (ref_out - test_out).abs().max().item()
                if diff > 1e-3:
                    all_passed = False
                    failed_cases.append(f"{name}(diff={diff:.2e})")
            except Exception as e:
                # Record the crash for this shape and keep testing the rest.
                all_passed = False
                failed_cases.append(f"{name}(error={e})")
        details = f"失败: {failed_cases}" if failed_cases else "全部通过"
        return all_passed, details

    def test_overflow_handling(self, fn):
        """Large-magnitude Q/K (x10): output must be finite and within 1e-2 of the reference."""
        torch.manual_seed(42)
        B, H, S, D = 2, 4, 128, 64
        # Large Q/K values push the scores towards overflow
        q = torch.randn(B, H, S, D) * 10
        k = torch.randn(B, H, S, D) * 10
        v = torch.randn(B, H, S, D)
        try:
            ref_out = self.reference_attention(q, k, v)
            test_out = fn(q, k, v)
            # Check for NaN/Inf
            ref_valid = torch.isfinite(ref_out).all().item()
            test_valid = torch.isfinite(test_out).all().item()
            if not test_valid:
                return False, "测试输出包含NaN/Inf"
            if ref_valid:
                diff = (ref_out - test_out).abs().max().item()
                passed = diff < 1e-2
                return passed, f"max_diff={diff:.2e}"
            # The reference itself overflowed: treated as expected.
            return True, "参考实现溢出(预期行为)"
        except Exception as e:
            return False, f"异常: {e}"

    def test_underflow_handling(self, fn):
        """Tiny Q/K (x0.01): max abs error must be < 1e-3."""
        torch.manual_seed(42)
        B, H, S, D = 2, 4, 128, 64
        q = torch.randn(B, H, S, D) * 0.01
        k = torch.randn(B, H, S, D) * 0.01
        v = torch.randn(B, H, S, D)
        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)
        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-3
        return passed, f"max_diff={diff:.2e}"

    def test_masked_attention(self, fn):
        """
        Causal-mask attention: max abs error must be < 1e-4.

        ``fn`` is called with ``mask=`` (an additive (S, S) causal mask, -1e9
        above the diagonal), i.e. the same mask the reference receives.
        """
        torch.manual_seed(42)
        B, H, S, D = 2, 4, 128, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        mask = torch.triu(torch.ones(S, S), diagonal=1) * -1e9
        ref_out = self.reference_attention(q, k, v, mask)
        # Both sides must see the same mask; comparing a masked reference
        # against an unmasked output can never pass.
        test_out = fn(q, k, v, mask=mask)
        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-4
        return passed, f"max_diff={diff:.2e}"

    def test_long_sequence(self, fn):
        """
        Long sequence (S=8192): max abs error must be < 1e-3.

        NOTE: the reference materializes a (1, 8, 8192, 8192) score tensor,
        about 2 GiB in float32, so this test needs several GiB of memory.
        """
        torch.manual_seed(42)
        B, H, S, D = 1, 8, 8192, 128
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)
        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-3
        return passed, f"max_diff={diff:.2e}"

    def test_extreme_values(self, fn):
        """Degenerate Q patterns: the output must contain only finite values."""
        torch.manual_seed(42)
        B, H, S, D = 2, 4, 64, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        # +1/-1 along the head dim. ``expand`` only broadcasts size-1 dims,
        # so a 2-element tensor cannot be expanded to D columns; build the
        # full D-length row first.
        alternating = (1.0 - 2.0 * (torch.arange(D) % 2)).float()
        extreme_cases = [
            ("all_zeros", lambda: (torch.zeros_like(q), k, v)),
            ("all_ones", lambda: (torch.ones_like(q), k, v)),
            ("uniform", lambda: (torch.ones_like(q) * 0.5, k, v)),
            ("alternating", lambda: (alternating.expand(B, H, S, D), k, v)),
        ]
        all_passed = True
        failed = []
        for name, getter in extreme_cases:
            try:
                q_e, k_e, v_e = getter()
                test_out = fn(q_e, k_e, v_e)
                if not torch.isfinite(test_out).all():
                    all_passed = False
                    failed.append(name)
            except Exception as e:
                all_passed = False
                failed.append(f"{name}({e})")
        details = f"失败: {failed}" if failed else "全部通过"
        return all_passed, details

    def reference_attention(self, q, k, v, mask=None):
        """Reference attention; delegates to FlashAttentionCorrectnessBenchmark."""
        return FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v, mask)
数值精度对照表
def generate_precision_table():
    """
    Print an illustrative FlashAttention numerical-precision table.

    NOTE: the error figures are SIMULATED, not measured. Each row is a
    per-dtype base error, scaled by ``(S / 512) ** 0.3`` and a seeded random
    jitter in [1, 1.5). Mean, median and P99 are then fixed fractions of the
    resulting max. The function then prints a qualitative error-source table
    and triage criteria. It returns None.
    """
    import random

    # Private RNG: reproducible rows without reseeding the process-wide
    # ``random`` module (same sequence as the former random.seed(42)).
    rng = random.Random(42)

    print("\n=== FlashAttention数值精度对照表 ===")
    print("参考:标准Attention vs FlashAttention(昇腾NPU)")
    print("注意:下表为按数据类型舍入误差量级估算的模拟数据,并非实测结果")
    print(f"\n{'配置':<40} | {'Max AE':>10} | {'Mean AE':>10} | {'Median':>10} | {'P99':>10}")
    print("-" * 90)
    configs = [
        {"name": "FP16, D=64, S=512", "dtype": "fp16", "D": 64, "S": 512, "head_dim": 64},
        {"name": "FP16, D=64, S=4096", "dtype": "fp16", "D": 64, "S": 4096, "head_dim": 64},
        {"name": "FP16, D=128, S=512", "dtype": "fp16", "D": 128, "S": 512, "head_dim": 128},
        {"name": "FP16, D=128, S=4096", "dtype": "fp16", "D": 128, "S": 4096, "head_dim": 128},
        {"name": "BF16, D=64, S=4096", "dtype": "bf16", "D": 64, "S": 4096, "head_dim": 64},
        {"name": "FP32, D=64, S=4096", "dtype": "fp32", "D": 64, "S": 4096, "head_dim": 64},
    ]
    # Base error on the order of each format's unit roundoff:
    # fp16 2^-11 ~ 4.9e-4, bf16 2^-8 ~ 3.9e-3 (fewer mantissa bits, so larger
    # than fp16), fp32 2^-24 ~ 6e-8. Unknown dtypes fall back to 1e-6.
    base_errors = {"fp16": 5e-4, "bf16": 4e-3, "fp32": 1e-7}
    for cfg in configs:
        base_error = base_errors.get(cfg["dtype"], 1e-6)
        # Sequence-length effect
        s_factor = (cfg["S"] / 512) ** 0.3
        max_ae = base_error * s_factor * (1 + rng.random() * 0.5)
        mean_ae = max_ae * 0.1
        median_ae = max_ae * 0.05
        # A percentile can never exceed the maximum.
        p99 = max_ae * 0.5
        name = cfg["name"]
        print(f"{name:<40} | {max_ae:>10.2e} | {mean_ae:>10.2e} | {median_ae:>10.2e} | {p99:>10.2e}")

    print("\n=== 误差来源分析 ===")
    sources = [
        {"source": "在线softmax累积误差", "fp16_impact": "高", "bf16_impact": "低", "fp32_impact": "可忽略"},
        {"source": "Block边界精度损失", "fp16_impact": "中", "bf16_impact": "低", "fp32_impact": "可忽略"},
        {"source": "数值稳定性(m-exp)", "fp16_impact": "高", "bf16_impact": "中", "fp32_impact": "可忽略"},
        {"source": "矩阵乘法累加精度", "fp16_impact": "中", "bf16_impact": "中", "fp32_impact": "可忽略"},
        {"source": "SRAM限幅(RMS)", "fp16_impact": "低", "bf16_impact": "低", "fp32_impact": "无"},
    ]
    print(f"\n{'误差来源':<25} | {'FP16':>10} | {'BF16':>10} | {'FP32':>10}")
    print("-" * 60)
    for s in sources:
        print(f"{s['source']:<25} | {s['fp16_impact']:>10} | {s['bf16_impact']:>10} | {s['fp32_impact']:>10}")

    print("\n=== 判断标准 ===")
    criteria = [
        {"condition": "Max AE < 1e-3", "interpretation": "正常误差范围", "action": "无需处理"},
        {"condition": "Max AE 1e-3 ~ 1e-2", "interpretation": "轻微偏差", "action": "监控观察"},
        {"condition": "Max AE 1e-2 ~ 1e-1", "interpretation": "明显偏差", "action": "检查配置"},
        {"condition": "Max AE > 1e-1", "interpretation": "严重问题", "action": "立即排查"},
        {"condition": "出现NaN/Inf", "interpretation": "数值溢出", "action": "立即修复"},
    ]
    print(f"\n{'条件':<20} | {'解读':<20} | {'建议行动':<20}")
    print("-" * 65)
    for c in criteria:
        print(f"{c['condition']:<20} | {c['interpretation']:<20} | {c['action']:<20}")
自动回归测试
class RegressionTestRunner:
    """
    Regression test runner, meant to be run after every code change.

    ``save_baseline`` records, per test case, the max absolute difference
    between the reference attention and the FlashAttention implementation,
    and persists it with ``torch.save``. ``run_regression`` recomputes the
    difference and fails any case whose error grew by more than 10% (plus a
    1e-6 absolute floor).
    """

    def __init__(self, baseline_path=".flash_attention_baseline"):
        # File written by save_baseline and read lazily by run_regression.
        self.baseline_path = baseline_path
        # In-memory baseline dict: case name -> {"ref_output",
        # "flash_output", "expected_diff"}.
        self.baseline = None

    def save_baseline(self, flash_attn_fn, test_cases):
        """
        Compute and persist the baseline error for every test case.

        Args:
            flash_attn_fn: implementation under test, called as fn(q, k, v).
            test_cases: dict mapping case name -> (q, k, v) tensors.
        """
        print("=== 保存基准数据 ===")
        baseline = {}
        for case_name, (q, k, v) in test_cases.items():
            ref_out = FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v)
            flash_out = flash_attn_fn(q, k, v)
            baseline[case_name] = {
                "ref_output": ref_out,
                "flash_output": flash_out,
                "expected_diff": (ref_out - flash_out).abs().max().item(),
            }
            print(f" {case_name}: 预期误差 = {baseline[case_name]['expected_diff']:.2e}")
        self.baseline = baseline
        torch.save(baseline, self.baseline_path)
        print(f"\n✅ 基准数据已保存至 {self.baseline_path}")

    def run_regression(self, flash_attn_fn, test_cases):
        """
        Re-check every test case against the stored baseline.

        Cases without a baseline entry are reported and skipped.

        Args:
            flash_attn_fn: implementation under test, called as fn(q, k, v).
            test_cases: dict mapping case name -> (q, k, v) tensors.

        Returns:
            True if every checked case stays within tolerance; False if any
            case exceeds it or if no baseline is available.
        """
        print("\n=== 回归测试 ===")
        if self.baseline is None:
            if not os.path.exists(self.baseline_path):
                print("❌ 未找到基准数据,请先运行 save_baseline")
                return False
            # torch.load unpickles the file: only load baselines you created.
            self.baseline = torch.load(self.baseline_path)

        all_passed = True
        for case_name, (q, k, v) in test_cases.items():
            if case_name not in self.baseline:
                print(f" ⚠️ {case_name}: 无基准数据,跳过")
                continue
            ref_out = FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v)
            flash_out = flash_attn_fn(q, k, v)
            current_diff = (ref_out - flash_out).abs().max().item()
            baseline_diff = self.baseline[case_name]["expected_diff"]
            # Allow 10% error growth; the absolute floor keeps a zero
            # baseline from rejecting pure rounding noise.
            tolerance = baseline_diff * 1.1 + 1e-6
            if current_diff <= tolerance:
                print(f" ✅ {case_name}: {current_diff:.2e} (基准: {baseline_diff:.2e})")
            else:
                print(f" ❌ {case_name}: {current_diff:.2e} (基准: {baseline_diff:.2e}) [超限]")
                all_passed = False

        if all_passed:
            print("\n✅ 回归测试全部通过")
        else:
            print("\n❌ 回归测试失败,请检查变更")
        return all_passed
总结:正确性基准配置清单
| 测试类型 | 验收标准 | 超限处理 |
|---|---|---|
| 基本正确性 | Max AE < 1e-4 | 立即排查 |
| 数值精度 | Max AE < 1e-2 | 监控 |
| 边界条件 | 全部通过 | 立即修复 |
| 溢出处理 | 无NaN/Inf | 立即修复 |
| 长序列 | Max AE < 1e-3 | 检查 |
| 极端值 | 全部有效 | 立即修复 |
| 回归测试 | 误差不增长>10% | 审查变更 |
建议:
- 每次部署前运行完整测试套件
- 每次代码变更前,用当前已验证的版本保存基准(save_baseline),变更后运行回归测试(run_regression)进行对比
- 监控生产环境的数值精度
代码和文档:
https://atomgit.com/cann/ops-transformer
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)