某团队在昇腾NPU上部署FlashAttention后,性能提升了8倍,但业务方报告:"模型输出变了一点,虽然不多,但不允许。"团队对比了NPU输出和GPU输出,发现结果有微小差异——最大相对误差约1e-3。他们不确定:这1e-3是正常的数值误差,还是bug?

问题出在没有建立正确性基准。FlashAttention在数学上与标准Attention等价(它是精确算法,而不是近似算法),但它引入了tiling、在线softmax、低精度累加等环节,改变了浮点运算的顺序,每个环节都可能带来数值误差。需要一个系统的基准测试框架,区分正常误差范围和异常bug。

今天把FlashAttention正确性基准的建立方法和数值精度对照表讲清楚。

正确性基准的设计

分层测试框架

import inspect
import math
import os
from typing import Dict, List, Tuple

import numpy as np
import torch

class FlashAttentionCorrectnessBenchmark:
    """
    Correctness benchmark helpers for FlashAttention.

    Goals:
      1. Verify algorithmic correctness (compare against standard attention).
      2. Build a numerical-precision table (error ranges per configuration).
      3. Tell normal numerical error apart from bugs.
    """

    def __init__(self, rtol=1e-3, atol=1e-6):
        # Tolerances are stored for callers; no method of this class reads them.
        self.rtol = rtol
        self.atol = atol

    def reference_attention(self, q, k, v, mask=None):
        """
        Standard (non-tiled) attention used as the correctness baseline.

        Computes ``softmax(q @ k^T / sqrt(d_k) + mask) @ v`` where ``d_k`` is
        the last dimension of ``q``.

        Args:
            q, k, v: Tensors whose last two dimensions are (sequence, feature).
            mask: Optional additive mask that broadcasts against the score
                matrix (e.g. large negative values at masked positions).

        Returns:
            The attention output tensor.
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores + mask

        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        return output

    def flash_attention(self, q, k, v, block_size=128, mask=None):
        """
        Simplified tiled attention with online softmax (pure PyTorch).

        This is a readable stand-in for a fused kernel: it walks over query
        blocks and, for each one, streams over key/value blocks while keeping
        a running row maximum, a running softmax denominator and an
        un-normalised output accumulator.

        Args:
            q: Query tensor of shape (B, H, S, D).
            k: Key tensor of shape (B, H, S_kv, D).
            v: Value tensor of shape (B, H, S_kv, D_v).
            block_size: Tile length along both sequence axes; must be > 0.
                The sequence length does not need to be a multiple of it.
            mask: Optional additive mask broadcastable to (..., S, S_kv),
                with the same meaning as in ``reference_attention``.

        Returns:
            Tensor of shape (B, H, S, D_v) with the dtype of ``q``.

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        B, H, S, D = q.shape
        S_kv = k.shape[-2]
        D_v = v.shape[-1]
        scale = 1.0 / math.sqrt(D)

        # Accumulate in at least float32 so low-precision inputs do not
        # degrade the running statistics.
        acc_dtype = torch.float64 if q.dtype == torch.float64 else torch.float32

        if mask is not None:
            # Materialise the (S, S_kv) trailing dims as a view so the mask
            # can be sliced per tile even when it was given in broadcast form.
            mask = mask.expand(torch.broadcast_shapes(mask.shape, (S, S_kv)))

        output = torch.zeros(B, H, S, D_v, dtype=q.dtype, device=q.device)

        for i in range(0, S, block_size):
            q_block = q[:, :, i:i + block_size, :].to(acc_dtype)
            bq = q_block.shape[-2]  # the last block may be shorter

            # Per-row running statistics, kept with a trailing singleton dim
            # so they broadcast against (B, H, bq, *) tensors.
            m_i = torch.full((B, H, bq, 1), -float('inf'),
                             dtype=acc_dtype, device=q.device)
            l_i = torch.zeros(B, H, bq, 1, dtype=acc_dtype, device=q.device)
            o_i = torch.zeros(B, H, bq, D_v, dtype=acc_dtype, device=q.device)

            for j in range(0, S_kv, block_size):
                k_block = k[:, :, j:j + block_size, :].to(acc_dtype)
                v_block = v[:, :, j:j + block_size, :].to(acc_dtype)

                s_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
                if mask is not None:
                    s_block = s_block + mask[..., i:i + block_size, j:j + block_size]

                m_new = torch.maximum(m_i, s_block.amax(dim=-1, keepdim=True))
                # If every score seen so far is -inf, subtracting m_new would
                # give inf - inf = NaN; use 0 as the shift for those rows.
                m_shift = torch.where(torch.isneginf(m_new),
                                      torch.zeros_like(m_new), m_new)

                alpha = torch.exp(m_i - m_shift)    # rescales old statistics
                p_block = torch.exp(s_block - m_shift)

                l_i = alpha * l_i + p_block.sum(dim=-1, keepdim=True)
                # o_i stays un-normalised; division by l_i happens once below.
                o_i = alpha * o_i + torch.matmul(p_block, v_block)
                m_i = m_new

            output[:, :, i:i + block_size, :] = (o_i / l_i).to(q.dtype)

        return output


class CorrectnessTestSuite:
    """
    Correctness test suite for a FlashAttention-style callable.

    Every ``test_*`` method takes a callable ``fn(q, k, v)`` and returns a
    ``(passed, details)`` tuple. The suite covers boundary shapes and
    numerical-stability scenarios.
    """

    def __init__(self):
        self.tests = []

    def run_all_tests(self, flash_attention_fn):
        """
        Run every test against ``flash_attention_fn`` and print a report.

        Args:
            flash_attention_fn: Callable invoked as ``fn(q, k, v)``.

        Returns:
            Dict mapping test name to ``{"passed": bool, "details": str}``,
            or ``{"passed": False, "error": str}`` when the test raised.
        """
        results = {}

        tests = [
            ("Basic", self.test_basic_correctness),
            ("NumericalPrecision", self.test_numerical_precision),
            ("BoundaryConditions", self.test_boundary_conditions),
            ("Overflow", self.test_overflow_handling),
            ("Underflow", self.test_underflow_handling),
            ("Masked", self.test_masked_attention),
            ("LongSequence", self.test_long_sequence),
            ("ExtremeValues", self.test_extreme_values),
        ]

        print("\n=== FlashAttention正确性测试套件 ===\n")

        for name, test_fn in tests:
            # A crashing test must not abort the rest of the suite, so any
            # exception is recorded as a failure for that test.
            try:
                passed, details = test_fn(flash_attention_fn)
                results[name] = {"passed": passed, "details": details}

                status = "✅ PASS" if passed else "❌ FAIL"
                print(f"  {name:<25} {status}")
                if details:
                    print(f"    {details}")
            except Exception as e:
                results[name] = {"passed": False, "error": str(e)}
                print(f"  {name:<25} ❌ ERROR: {e}")

        total = len(results)
        passed = sum(1 for r in results.values() if r["passed"])
        print(f"\n结果: {passed}/{total} 通过")

        return results

    def test_basic_correctness(self, fn):
        """Compare ``fn`` with the reference on a small random input."""
        torch.manual_seed(42)

        B, H, S, D = 2, 4, 128, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)

        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)

        max_diff = (ref_out - test_out).abs().max().item()
        # NOTE(review): element-wise relative error grows where ref_out is
        # close to zero; confirm that the 1e-2 bound is meant to apply there.
        rtol = (ref_out - test_out).abs() / (ref_out.abs() + 1e-8)
        max_rtol = rtol.max().item()

        passed = max_diff < 1e-4 and max_rtol < 1e-2

        return passed, f"max_diff={max_diff:.2e}, max_rtol={max_rtol:.2e}"

    def test_numerical_precision(self, fn):
        """Collect absolute-error statistics on a medium-sized random input."""
        torch.manual_seed(123)

        B, H, S, D = 2, 8, 512, 128
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)

        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)

        # Error distribution statistics.
        diff = (ref_out - test_out).abs()

        metrics = {
            "max_abs_error": diff.max().item(),
            "mean_abs_error": diff.mean().item(),
            "median_abs_error": diff.median().item(),
            "p99_abs_error": torch.quantile(diff, 0.99).item(),
        }

        # Only the maximum absolute error decides pass/fail.
        passed = metrics["max_abs_error"] < 1e-2

        details = f"max={metrics['max_abs_error']:.2e}, mean={metrics['mean_abs_error']:.2e}"

        return passed, details

    def test_boundary_conditions(self, fn):
        """Check a list of edge-case shapes; every case must stay within 1e-3."""
        # Seeded like the other tests so failures are reproducible.
        torch.manual_seed(0)

        tests = [
            ("S=1", (1, 1, 1, 64)),
            ("S=2", (1, 1, 2, 64)),
            ("S=block_size", (1, 1, 128, 64)),
            ("S=block_size+1", (1, 1, 129, 64)),
            ("S=2*block_size", (1, 1, 256, 64)),
            ("D=1", (1, 1, 64, 1)),
            ("D=256", (1, 1, 64, 256)),
            ("H=1", (1, 1, 64, 64)),
            ("H=128", (1, 128, 64, 64)),
            ("B=1", (1, 8, 128, 64)),
            ("B=32", (32, 8, 128, 64)),
        ]

        all_passed = True
        failed_cases = []

        for name, shape in tests:
            # One failing shape must not hide the others, so errors are
            # collected per case.
            try:
                q = torch.randn(*shape)
                k = torch.randn(*shape)
                v = torch.randn(*shape)

                ref_out = self.reference_attention(q, k, v)
                test_out = fn(q, k, v)

                diff = (ref_out - test_out).abs().max().item()
                if diff > 1e-3:
                    all_passed = False
                    failed_cases.append(f"{name}(diff={diff:.2e})")
            except Exception as e:
                all_passed = False
                failed_cases.append(f"{name}(error={e})")

        details = f"失败: {failed_cases}" if failed_cases else "全部通过"

        return all_passed, details

    def test_overflow_handling(self, fn):
        """Feed large-magnitude Q/K and require a finite, accurate output."""
        torch.manual_seed(42)

        B, H, S, D = 2, 4, 128, 64

        # Large Q/K values produce large attention scores.
        q = torch.randn(B, H, S, D) * 10
        k = torch.randn(B, H, S, D) * 10
        v = torch.randn(B, H, S, D)

        try:
            ref_out = self.reference_attention(q, k, v)
            test_out = fn(q, k, v)

            # Check for NaN/Inf.
            ref_valid = torch.isfinite(ref_out).all().item()
            test_valid = torch.isfinite(test_out).all().item()

            if not test_valid:
                return False, "测试输出包含NaN/Inf"

            if ref_valid:
                diff = (ref_out - test_out).abs().max().item()
                passed = diff < 1e-2
                return passed, f"max_diff={diff:.2e}"
            else:
                # The reference itself overflowed, so there is nothing to
                # compare against; a finite test output counts as a pass.
                return True, "参考实现溢出(预期行为)"

        except Exception as e:
            return False, f"异常: {e}"

    def test_underflow_handling(self, fn):
        """Feed small-magnitude Q/K and compare against the reference."""
        torch.manual_seed(42)

        B, H, S, D = 2, 4, 128, 64

        q = torch.randn(B, H, S, D) * 0.01
        k = torch.randn(B, H, S, D) * 0.01
        v = torch.randn(B, H, S, D)

        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)

        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-3

        return passed, f"max_diff={diff:.2e}"

    @staticmethod
    def _causal_call_kwargs(fn, mask):
        """Return the keyword arguments that request causal masking from ``fn``.

        Looks at the signature of ``fn`` for a mask parameter (``mask`` or
        ``attn_mask``) or a causal flag (``causal`` or ``is_causal``).
        Returns an empty dict when none is found or the signature cannot be
        inspected.
        """
        try:
            params = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            return {}

        for key in ("mask", "attn_mask"):
            if key in params:
                return {key: mask}
        for key in ("causal", "is_causal"):
            if key in params:
                return {key: True}
        return {}

    def test_masked_attention(self, fn):
        """
        Compare causal attention from ``fn`` against the masked reference.

        If ``fn`` exposes a mask parameter or a causal flag, it is passed so
        both sides compute the same thing. Otherwise ``fn(q, k, v)`` is called
        unchanged, which is only meaningful when ``fn`` applies causal masking
        on its own; the details string says so.
        """
        torch.manual_seed(42)

        B, H, S, D = 2, 4, 128, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)

        # Additive causal mask: large negative above the diagonal.
        mask = torch.triu(torch.ones(S, S), diagonal=1) * -1e9

        extra = self._causal_call_kwargs(fn, mask)

        ref_out = self.reference_attention(q, k, v, mask)
        test_out = fn(q, k, v, **extra)

        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-4

        details = f"max_diff={diff:.2e}"
        if not extra:
            details += " (fn未暴露mask/causal参数,按fn自身已启用causal处理)"

        return passed, details

    def test_long_sequence(self, fn):
        """Compare against the reference at sequence length 8192."""
        torch.manual_seed(42)

        # The reference builds the full (B, H, S, S) score matrix for this
        # shape, so this test is memory-hungry.
        B, H, S, D = 1, 8, 8192, 128
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)

        ref_out = self.reference_attention(q, k, v)
        test_out = fn(q, k, v)

        diff = (ref_out - test_out).abs().max().item()
        passed = diff < 1e-3

        return passed, f"max_diff={diff:.2e}"

    def test_extreme_values(self, fn):
        """Run ``fn`` on degenerate queries and require finite outputs."""
        torch.manual_seed(42)

        B, H, S, D = 2, 4, 64, 64
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)

        # Each case is built lazily so that a construction error is reported
        # for that case only.
        extreme_cases = [
            ("all_zeros", lambda: (torch.zeros_like(q), k, v)),
            ("all_ones", lambda: (torch.ones_like(q), k, v)),
            ("uniform", lambda: (torch.ones_like(q) * 0.5, k, v)),
            # Tile the +1/-1 pair across the feature dimension first; a
            # length-2 tensor cannot be expanded to length D directly.
            ("alternating", lambda: (
                torch.tensor([1.0, -1.0]).repeat(D // 2).expand(B, H, S, D), k, v)),
        ]

        all_passed = True
        failed = []

        for name, getter in extreme_cases:
            try:
                q_e, k_e, v_e = getter()
                test_out = fn(q_e, k_e, v_e)

                if not torch.isfinite(test_out).all():
                    all_passed = False
                    failed.append(name)
            except Exception as e:
                all_passed = False
                failed.append(f"{name}({e})")

        details = f"失败: {failed}" if failed else "全部通过"

        return all_passed, details

    def reference_attention(self, q, k, v, mask=None):
        """Delegate to ``FlashAttentionCorrectnessBenchmark.reference_attention``."""
        return FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v, mask)

数值精度对照表

def generate_precision_table():
    """
    Print a numerical-precision lookup table for FlashAttention.

    The error figures are simulated, not measured: each row is derived from a
    per-dtype base error, a sequence-length factor and a seeded pseudo-random
    jitter. The function then prints an error-source table and the
    interpretation criteria. Nothing is returned.
    """
    print("\n=== FlashAttention数值精度对照表 ===")
    print("参考:标准Attention vs FlashAttention(昇腾NPU)")

    print(f"\n{'配置':<40} | {'Max AE':>10} | {'Mean AE':>10} | {'Median':>10} | {'P99':>10}")
    print("-" * 90)

    configs = [
        {"name": "FP16, D=64, S=512", "dtype": "fp16", "D": 64, "S": 512, "head_dim": 64},
        {"name": "FP16, D=64, S=4096", "dtype": "fp16", "D": 64, "S": 4096, "head_dim": 64},
        {"name": "FP16, D=128, S=512", "dtype": "fp16", "D": 128, "S": 512, "head_dim": 128},
        {"name": "FP16, D=128, S=4096", "dtype": "fp16", "D": 128, "S": 4096, "head_dim": 128},
        {"name": "BF16, D=64, S=4096", "dtype": "bf16", "D": 64, "S": 4096, "head_dim": 64},
        {"name": "FP32, D=64, S=4096", "dtype": "fp32", "D": 64, "S": 4096, "head_dim": 64},
    ]

    import random
    # A private generator gives the same numbers as seeding the global one,
    # without changing the caller's random state.
    rng = random.Random(42)

    # Simulated base error per dtype; unknown dtypes fall back to 1e-6.
    base_errors = {"fp16": 5e-4, "bf16": 2e-4, "fp32": 1e-7}

    for cfg in configs:
        base_error = base_errors.get(cfg["dtype"], 1e-6)

        # Sequence-length effect, relative to S=512.
        s_factor = (cfg["S"] / 512) ** 0.3

        max_ae = base_error * s_factor * (1 + rng.random() * 0.5)
        mean_ae = max_ae * 0.1
        median_ae = max_ae * 0.05
        # A percentile cannot exceed the maximum, so P99 is placed between
        # the mean and the maximum.
        p99 = max_ae * 0.5

        name = cfg["name"]

        print(f"{name:<40} | {max_ae:>10.2e} | {mean_ae:>10.2e} | {median_ae:>10.2e} | {p99:>10.2e}")

    print("\n=== 误差来源分析 ===")

    sources = [
        {"source": "在线softmax累积误差", "fp16_impact": "高", "bf16_impact": "低", "fp32_impact": "可忽略"},
        {"source": "Block边界精度损失", "fp16_impact": "中", "bf16_impact": "低", "fp32_impact": "可忽略"},
        {"source": "数值稳定性(m-exp)", "fp16_impact": "高", "bf16_impact": "中", "fp32_impact": "可忽略"},
        {"source": "矩阵乘法累加精度", "fp16_impact": "中", "bf16_impact": "中", "fp32_impact": "可忽略"},
        {"source": "SRAM限幅(RMS)", "fp16_impact": "低", "bf16_impact": "低", "fp32_impact": "无"},
    ]

    print(f"\n{'误差来源':<25} | {'FP16':>10} | {'BF16':>10} | {'FP32':>10}")
    print("-" * 60)

    for s in sources:
        print(f"{s['source']:<25} | {s['fp16_impact']:>10} | {s['bf16_impact']:>10} | {s['fp32_impact']:>10}")

    print("\n=== 判断标准 ===")

    criteria = [
        {"condition": "Max AE < 1e-3", "interpretation": "正常误差范围", "action": "无需处理"},
        {"condition": "Max AE 1e-3 ~ 1e-2", "interpretation": "轻微偏差", "action": "监控观察"},
        {"condition": "Max AE 1e-2 ~ 1e-1", "interpretation": "明显偏差", "action": "检查配置"},
        {"condition": "Max AE > 1e-1", "interpretation": "严重问题", "action": "立即排查"},
        {"condition": "出现NaN/Inf", "interpretation": "数值溢出", "action": "立即修复"},
    ]

    print(f"\n{'条件':<20} | {'解读':<20} | {'建议行动':<20}")
    print("-" * 65)

    for c in criteria:
        print(f"{c['condition']:<20} | {c['interpretation']:<20} | {c['action']:<20}")

自动回归测试

class RegressionTestRunner:
    """
    Regression test runner.

    Records the error between the reference attention and a FlashAttention
    callable as a baseline, then checks after code changes that the error has
    not grown beyond the allowed margin.
    """

    def __init__(self, baseline_path=".flash_attention_baseline"):
        self.baseline_path = baseline_path
        # Loaded lazily by run_regression, or set by save_baseline.
        self.baseline = None

    def save_baseline(self, flash_attn_fn, test_cases):
        """
        Compute and persist baseline data for every test case.

        Args:
            flash_attn_fn: Callable invoked as ``fn(q, k, v)``.
            test_cases: Dict mapping case name to a ``(q, k, v)`` tuple.

        The baseline (reference output, FlashAttention output and their
        maximum absolute difference per case) is kept on the instance and
        written to ``self.baseline_path`` with ``torch.save``.
        """
        print("=== 保存基准数据 ===")

        baseline = {}

        for case_name, (q, k, v) in test_cases.items():
            ref_out = FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v)
            flash_out = flash_attn_fn(q, k, v)

            baseline[case_name] = {
                "ref_output": ref_out,
                "flash_output": flash_out,
                "expected_diff": (ref_out - flash_out).abs().max().item()
            }

            print(f"  {case_name}: 预期误差 = {baseline[case_name]['expected_diff']:.2e}")

        self.baseline = baseline

        torch.save(baseline, self.baseline_path)
        print(f"\n✅ 基准数据已保存至 {self.baseline_path}")

    def run_regression(self, flash_attn_fn, test_cases):
        """
        Compare the current error of each test case against the baseline.

        Args:
            flash_attn_fn: Callable invoked as ``fn(q, k, v)``.
            test_cases: Dict mapping case name to a ``(q, k, v)`` tuple.
                Cases without a baseline entry are skipped with a warning.

        Returns:
            True if every compared case stays within tolerance. False if any
            case exceeds it, or if no baseline is available.
        """
        print("\n=== 回归测试 ===")

        if self.baseline is None:
            if not os.path.exists(self.baseline_path):
                print("❌ 未找到基准数据,请先运行 save_baseline")
                # Without a baseline nothing was verified, so report failure.
                return False
            self.baseline = torch.load(self.baseline_path)

        all_passed = True

        for case_name, (q, k, v) in test_cases.items():
            if case_name not in self.baseline:
                print(f"  ⚠️ {case_name}: 无基准数据,跳过")
                continue

            ref_out = FlashAttentionCorrectnessBenchmark().reference_attention(q, k, v)
            flash_out = flash_attn_fn(q, k, v)

            current_diff = (ref_out - flash_out).abs().max().item()
            baseline_diff = self.baseline[case_name]["expected_diff"]

            # Allow 10% growth, plus a small absolute slack so a baseline
            # error of zero does not make the check impossible to pass.
            tolerance = baseline_diff * 1.1 + 1e-6

            passed = current_diff <= tolerance

            if passed:
                print(f"  ✅ {case_name}: {current_diff:.2e} (基准: {baseline_diff:.2e})")
            else:
                print(f"  ❌ {case_name}: {current_diff:.2e} (基准: {baseline_diff:.2e}) [超限]")
                all_passed = False

        if all_passed:
            print("\n✅ 回归测试全部通过")
        else:
            print("\n❌ 回归测试失败,请检查变更")

        return all_passed

总结:正确性基准配置清单

| 测试类型 | 验收标准 | 超限处理 |
| --- | --- | --- |
| 基本正确性 | Max AE < 1e-4 | 立即排查 |
| 数值精度 | Max AE < 1e-2 | 监控 |
| 边界条件 | 全部通过 | 立即修复 |
| 溢出处理 | 无NaN/Inf | 立即修复 |
| 长序列 | Max AE < 1e-3 | 检查 |
| 极端值 | 全部有效 | 立即修复 |
| 回归测试 | 误差不增长>10% | 审查变更 |

建议

  • 每次部署前运行完整测试套件
  • 每次代码变更前保存基准
  • 监控生产环境的数值精度

代码和文档:

https://atomgit.com/cann/ops-transformer

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐