第一章:算法描述

1.1 算法名称

HISDMA = Hierarchical Indexed Sparse Dynamic Memory Attention

1.2 问题定义

标准注意力计算:

o∗=∑j=1Nexp⁡(q⊤kj)vj∑j=1Nexp⁡(q⊤kj)o^* = \frac{\sum_{j=1}^N \exp(q^\top k_j) v_j}{\sum_{j=1}^N \exp(q^\top k_j)}o=j=1Nexp(qkj)j=1Nexp(qkj)vj

目标:在误差可控的前提下,减少计算量,实现 O(Nlog⁡N)O(N \log N)O(NlogN) 期望复杂度。

1.3 核心思想

利用层次化聚类树对键进行空间划分,通过上界剪枝策略,只计算"重要"的键,实现近似注意力计算。

1.4 算法伪代码

输入: 查询 q, 键集合 K={k_1,...,k_N}, 值集合 V={v_1,...,v_N}, 误差阈值 τ_0
输出: 近似注意力输出 ŏ

=== 预处理阶段 ===
1. 构建二叉聚类树 T:
   - 递归划分键空间,每个节点 n 存储区域 R_n
   - 计算每个节点的统计量:质心 μ_n, 半径 σ_n, 索引集 I_n

=== 在线推理阶段 ===
2. 初始化:
   - P = ∅           (已处理索引)
   - U = {1,...,N}   (未处理索引)
   - Q = {root}      (优先队列,存节点)
   - D = 0, N = 0    (分母累积, 分子累积)
   - m = -∞          (最大点积)

3. While Q 非空:
     // 计算上界
     对每个 n ∈ Q: U_n = q·μ_n + σ_n + γ√(log N / |I_n|)
     
     // 选择最大上界节点
     n* = argmax_{n∈Q} U_n
     
     // 剪枝检查
     If 2R_v · |U| · exp(U_n*) / D ≤ τ_0:
         Break  // 停止条件满足
     
     // 处理节点
     If n* 是叶节点:
         For j ∈ I_n*:
             s_j = q·k_j
             m = max(m, s_j)
             // 对数域稳定更新
             D = D · exp(m_old - m) + exp(s_j - m)
             N = N · exp(m_old - m) + exp(s_j - m) · v_j
             P = P ∪ {j}
         Q = Q \ {n*}
     Else:
         // 展开内部节点
         Q = Q \ {n*} ∪ {left_child(n*), right_child(n*)}
     
     U = U \ I_n*

4. 返回: ŏ = N / D

1.5 直观解释

组件 作用
聚类树 将相似键聚集,一次处理一批
上界 UnU_nUn 估计节点内键的最大贡献
优先队列 优先处理高上界节点(重要区域)
停止准则 当剩余节点贡献足够小时停止

1.6 关键公式

上界估计

Un=q⊤μn+σn+γlog⁡N∣In∣U_n = q^\top \mu_n + \sigma_n + \gamma\sqrt{\frac{\log N}{|I_n|}}Un=qμn+σn+γInlogN

停止条件

2Rv∣Ut∣exp⁡(ϵt)Dt≤τ0\frac{2R_v |U_t| \exp(\epsilon_t)}{D_t} \leq \tau_0Dt2RvUtexp(ϵt)τ0


第二章:基础设定与符号规范

2.1 概率空间与随机变量

设注意力系统运行于概率空间 (Ω,F,P)(\Omega, \mathcal{F}, P)(Ω,F,P)。定义随机变量:

  • 查询 q:Ω→Rdq: \Omega \to \mathbb{R}^dq:ΩRd,分布为 μQ\mu_QμQ
  • {kj}j=1N:Ω→Rd\{k_j\}_{j=1}^N: \Omega \to \mathbb{R}^d{kj}j=1N:ΩRd,独立同分布,分布为 μK\mu_KμK
  • {vj}j=1N:Ω→Rdv\{v_j\}_{j=1}^N: \Omega \to \mathbb{R}^{d_v}{vj}j=1N:ΩRdv,满足 ∥vj∥2≤Rv\|v_j\|_2 \leq R_vvj2Rv 几乎必然

2.2 注意力机制的测度论描述

对于固定 qqq,定义随机变量 sj=q⊤kjs_j = q^\top k_jsj=qkjaj=exp⁡(sj)a_j = \exp(s_j)aj=exp(sj)。标准注意力输出:

o∗=∑j=1Najvj∑j=1Najo^* = \frac{\sum_{j=1}^N a_j v_j}{\sum_{j=1}^N a_j}o=j=1Najj=1Najvj

注意:分母 D=∑j=1Naj>0D = \sum_{j=1}^N a_j > 0D=j=1Naj>0 几乎必然,因为 aj>0a_j > 0aj>0

第三章:层次化聚类树的构造与性质

定理 3.1(聚类树的存在性构造)

对任意 ϵ>0\epsilon > 0ϵ>0,存在二叉树 Tϵ\mathcal{T}_\epsilonTϵ 满足:

  1. 每个节点 nnn 对应 Rd\mathbb{R}^dRd 中凸区域 Rn\mathcal{R}_nRn
  2. 叶区域 {Rl}\{\mathcal{R}_l\}{Rl} 构成 supp(μK)\text{supp}(\mu_K)supp(μK) 的划分
  3. ∀n,sup⁡x∈Rn∥x−μn∥2≤ϵ⋅diam(Rn)\forall n, \sup_{x \in \mathcal{R}_n} \|x - \mu_n\|_2 \leq \epsilon \cdot \text{diam}(\mathcal{R}_n)n,supxRnxμn2ϵdiam(Rn)
  4. 树高 h(Tϵ)≤Cdlog⁡(1/ϵ)h(\mathcal{T}_\epsilon) \leq C_d \log(1/\epsilon)h(Tϵ)Cdlog(1/ϵ)CdC_dCd 仅依赖维度 ddd

证明:递归二分。每次选择主方向,保证子区域直径至少减少 (1−δ)(1-\delta)(1δ) 倍。取 δ=1/2\delta = 1/2δ=1/2,则深度为 log⁡2(1/ϵ)\log_2(1/\epsilon)log2(1/ϵ)。∎

定义 3.2(理想与经验统计量)

对节点 nnn,定义:

  • 理论质心:μn∗=E[k∣k∈Rn]\mu_n^* = \mathbb{E}[k | k \in \mathcal{R}_n]μn=E[kkRn]
  • 经验质心:μ^n=1∣In∣∑j∈Inkj\hat{\mu}_n = \frac{1}{|I_n|} \sum_{j \in I_n} k_jμ^n=In1jInkj
  • 理论半径:σn2=E[∥k−μn∗∥22∣k∈Rn]\sigma_n^2 = \mathbb{E}[\|k - \mu_n^*\|_2^2 | k \in \mathcal{R}_n]σn2=E[kμn22kRn]
  • 经验半径:σ^n2=max⁡j∈In∥kj−μ^n∥22\hat{\sigma}_n^2 = \max_{j \in I_n} \|k_j - \hat{\mu}_n\|_2^2σ^n2=maxjInkjμ^n22

引理 3.3(集中不等式)

{kj}\{k_j\}{kj} 独立同分布,E[∥k∥22]<∞\mathbb{E}[\|k\|_2^2] < \inftyE[k22]<。则 ∀δ>0\forall \delta > 0δ>0

P(∥μ^n−μn∗∥2≥σn2log⁡(2/δ)∣In∣)≤δP\left( \|\hat{\mu}_n - \mu_n^*\|_2 \geq \sqrt{\frac{\sigma_n^2 \log(2/\delta)}{|I_n|}} \right) \leq \deltaP(μ^nμn2Inσn2log(2/δ) )δ

P(σ^n2≥σn2+clog⁡(1/δ)∣In∣)≤δP\left( \hat{\sigma}_n^2 \geq \sigma_n^2 + c\sqrt{\frac{\log(1/\delta)}{|I_n|}} \right) \leq \deltaP(σ^n2σn2+cInlog(1/δ) )δ

证明:应用 Bernstein 不等式和方差的有界性。∎

第四章:核心不等式体系的强化

定理 4.1(点积的高概率界)

对任意固定 qqq∥q∥2=1\|q\|_2=1q2=1,节点 nnn∀δ>0\forall \delta > 0δ>0,以概率 ≥1−δ\geq 1-\delta1δ

max⁡j∈Inq⊤kj≤q⊤μ^n+r^n(δ)\max_{j \in I_n} q^\top k_j \leq q^\top \hat{\mu}_n + \hat{r}_n(\delta)jInmaxqkjqμ^n+r^n(δ)

其中 r^n(δ)=σ^n+log⁡(1/δ)∣In∣\hat{r}_n(\delta) = \hat{\sigma}_n + \sqrt{\frac{\log(1/\delta)}{|I_n|}}r^n(δ)=σ^n+Inlog(1/δ)

证明:由三角不等式:

q⊤kj=q⊤μ^n+q⊤(kj−μ^n)≤q⊤μ^n+∥kj−μ^n∥2q^\top k_j = q^\top \hat{\mu}_n + q^\top (k_j - \hat{\mu}_n) \leq q^\top \hat{\mu}_n + \|k_j - \hat{\mu}_n\|_2qkj=qμ^n+q(kjμ^n)qμ^n+kjμ^n2

再应用集中不等式于 max⁡j∈In∥kj−μ^n∥2\max_{j \in I_n} \|k_j - \hat{\mu}_n\|_2maxjInkjμ^n2。∎

推论 4.2(保守上界)

定义保守上界:

Un=q⊤μ^n+σ^n+γlog⁡N∣In∣U_n = q^\top \hat{\mu}_n + \hat{\sigma}_n + \gamma\sqrt{\frac{\log N}{|I_n|}}Un=qμ^n+σ^n+γInlogN

γ≥1\gamma \geq 1γ1,则:

P(∃j∈In:q⊤kj>Un)≤N−cγ2P\left( \exists j \in I_n: q^\top k_j > U_n \right) \leq N^{-c\gamma^2}P(jIn:qkj>Un)Ncγ2

对适当常数 c>0c>0c>0 成立。

定理 4.3(指数矩的次高斯界)

假设条件分布 k∣Rnk|_{\mathcal{R}_n}kRnσn\sigma_nσn-次高斯的,即 ∀v∈Rd\forall v \in \mathbb{R}^dvRd

E[exp⁡(λv⊤(k−μn∗))]≤exp⁡(λ2σn2∥v∥22/2)\mathbb{E}[\exp(\lambda v^\top (k - \mu_n^*))] \leq \exp(\lambda^2 \sigma_n^2 \|v\|_2^2/2)E[exp(λv(kμn))]exp(λ2σn2v22/2)

则对 s=q⊤ks = q^\top ks=qk

E[exp⁡(s)∣k∈Rn]≤exp⁡(q⊤μn∗+σn2/2)\mathbb{E}[\exp(s) | k \in \mathcal{R}_n] \leq \exp(q^\top \mu_n^* + \sigma_n^2/2)E[exp(s)kRn]exp(qμn+σn2/2)

证明:取 v=qv=qv=qλ=1\lambda=1λ=1,由次高斯性得。∎

第五章:算法过程的随机分析

定义 5.1(适应过程)

Ft\mathcal{F}_tFtttt 步处理后的信息 σ\sigmaσ-代数。定义:

  • Pt={j:j 已被处理}P_t = \{j: j \text{ 已被处理}\}Pt={j:j 已被处理}Ft\mathcal{F}_tFt-可测
  • Ut={1,…,N}∖PtU_t = \{1,\dots,N\} \setminus P_tUt={1,,N}Pt
  • QtQ_tQt:优先队列中节点集合
  • Dt=∑j∈PtajD_t = \sum_{j \in P_t} a_jDt=jPtaj
  • ϵt=max⁡n∈QtUn\epsilon_t = \max_{n \in Q_t} U_nϵt=maxnQtUn

引理 5.2(过程的单调性)

DtD_tDtFt\mathcal{F}_tFt-适应的下鞅,Dt≥0D_t \geq 0Dt0Dt↑DD_t \uparrow DDtD 几乎必然。

ϵt\epsilon_tϵtFt\mathcal{F}_tFt-适应的上鞅,ϵt↓ϵ∞≥max⁡j∈U∞q⊤kj\epsilon_t \downarrow \epsilon_\infty \geq \max_{j \in U_\infty} q^\top k_jϵtϵmaxjUqkj

定理 5.3(停止时间的几乎必然有限性)

定义停止时间 τ=inf⁡{t≥0:2Rv∣Ut∣exp⁡(ϵt)Dt≤τ0}\tau = \inf\{t \geq 0: \frac{2R_v |U_t| \exp(\epsilon_t)}{D_t} \leq \tau_0\}τ=inf{t0:Dt2RvUtexp(ϵt)τ0}

P(τ<∞)=1P(\tau < \infty) = 1P(τ<)=1

证明:考虑事件 E={τ=∞}E = \{\tau = \infty\}E={τ=}。在 EEE 上,∀t\forall tt

2Rv∣Ut∣exp⁡(ϵt)Dt>τ0\frac{2R_v |U_t| \exp(\epsilon_t)}{D_t} > \tau_0Dt2RvUtexp(ϵt)>τ0

∣Ut∣exp⁡(ϵt)→0|U_t| \exp(\epsilon_t) \to 0Utexp(ϵt)0(因 ∣Ut∣→0|U_t| \to 0Ut0ϵt\epsilon_tϵt 有界),而 Dt→D>0D_t \to D > 0DtD>0,矛盾。∎

定理 5.4(停止时间的矩)

α=min⁡jE[aj]/N>0\alpha = \min_j \mathbb{E}[a_j]/N > 0α=minjE[aj]/N>0,则:

E[τ]≤log⁡(1/τ0)+log⁡(2RvN/α)log⁡(1/β)\mathbb{E}[\tau] \leq \frac{\log(1/\tau_0) + \log(2R_v N/\alpha)}{\log(1/\beta)}E[τ]log(1/β)log(1/τ0)+log(2RvN/α)

其中 β<1\beta < 1β<1 是每步 ∣Ut∣exp⁡(ϵt)|U_t|\exp(\epsilon_t)Utexp(ϵt) 的衰减率。

证明:构造辅助过程 Xt=log⁡(∣Ut∣exp⁡(ϵt))−log⁡DtX_t = \log(|U_t|\exp(\epsilon_t)) - \log D_tXt=log(Utexp(ϵt))logDt,分析其漂移。∎

第六章:近似误差的分布分析

定理 6.1(误差的条件期望)

o~t=∑j∈PtajvjDt\tilde{o}_t = \frac{\sum_{j \in P_t} a_j v_j}{D_t}o~t=DtjPtajvj。则 ∀t\forall tt

E[∥o~t−o∗∥2∣Ft]≤2Rv∣Ut∣exp⁡(ϵt)Dt\mathbb{E}[\|\tilde{o}_t - o^*\|_2 | \mathcal{F}_t] \leq \frac{2R_v |U_t| \exp(\epsilon_t)}{D_t}E[o~to2Ft]Dt2RvUtexp(ϵt)

几乎必然成立。

证明:由确定性不等式(三角不等式)条件期望得。∎

推论 6.2(停止时的误差界)

在停止时间 τ\tauτ

∥o~τ−o∗∥2≤τ0a.s.\|\tilde{o}_\tau - o^*\|_2 \leq \tau_0 \quad \text{a.s.}o~τo2τ0a.s.

定理 6.3(误差的集中性)

假设 aja_jaj 独立(给定 qqq),且 ∥vj∥2≤Rv\|v_j\|_2 \leq R_vvj2Rv,则 ∀δ>0\forall \delta > 0δ>0

P(∥o~τ−o∗∥2≥τ0+2Rvδ∣Uτ∣exp⁡(2ϵτ)Dτ2)≤δP\left( \|\tilde{o}_\tau - o^*\|_2 \geq \tau_0 + \frac{2R_v}{\sqrt{\delta}} \sqrt{\frac{|U_\tau| \exp(2\epsilon_\tau)}{D_\tau^2}} \right) \leq \deltaP(o~τo2τ0+δ 2RvDτ2Uτexp(2ϵτ) )δ

证明:应用 Chebyshev 不等式于 o~τ−o∗\tilde{o}_\tau - o^*o~τo 的条件方差。∎

第七章:后向传播的显式误差分解

设定 7.1(可微性)

损失函数 L:Rdv→R\mathcal{L}: \mathbb{R}^{d_v} \to \mathbb{R}L:RdvR 满足:

  1. ∇L\nabla \mathcal{L}L 存在且 LgL_gLg-Lipschitz
  2. ∥∇L(x)∥2≤G\|\nabla \mathcal{L}(x)\|_2 \leq G∥∇L(x)2G
  3. ∥∇2L(x)∥op≤H\|\nabla^2 \mathcal{L}(x)\|_{\text{op}} \leq H2L(x)opH

定理 7.2(梯度误差的显式表达)

g∗=∇L(o∗)g^* = \nabla \mathcal{L}(o^*)g=L(o)g~=∇L(o~τ)\tilde{g} = \nabla \mathcal{L}(\tilde{o}_\tau)g~=L(o~τ)。则:

∇qL−∇qL~=∑j∈Pτaj(1D−1Dτ)(g∗⊤(vj−o∗))kj⏟(I)+∑j∈PτajDτ((g∗−g~)⊤(vj−o∗))kj⏟(II)+∑j∈PτajDτ(g~⊤(o∗−o~τ))kj⏟(III)+∑j∈UτajD(g∗⊤(vj−o∗))kj⏟(IV)\begin{aligned} \nabla_q \mathcal{L} - \nabla_q \tilde{\mathcal{L}} &= \underbrace{\sum_{j \in P_\tau} a_j \left(\frac{1}{D} - \frac{1}{D_\tau}\right)(g^{*\top}(v_j - o^*))k_j}_{(I)} \\ &+ \underbrace{\sum_{j \in P_\tau} \frac{a_j}{D_\tau}((g^* - \tilde{g})^\top(v_j - o^*))k_j}_{(II)} \\ &+ \underbrace{\sum_{j \in P_\tau} \frac{a_j}{D_\tau}(\tilde{g}^\top(o^* - \tilde{o}_\tau))k_j}_{(III)} \\ &+ \underbrace{\sum_{j \in U_\tau} \frac{a_j}{D}(g^{*\top}(v_j - o^*))k_j}_{(IV)} \end{aligned}qLqL~=(I) jPτaj(D1Dτ1)(g(vjo))kj+(II) jPτDτaj((gg~)(vjo))kj+(III) jPτDτaj(g~(oo~τ))kj+(IV) jUτDaj(g(vjo))kj

证明:直接计算精确梯度 ∇qL=∑j=1NajD(g∗⊤(vj−o∗))kj\nabla_q \mathcal{L} = \sum_{j=1}^N \frac{a_j}{D}(g^{*\top}(v_j - o^*))k_jqL=j=1NDaj(g(vjo))kj 和近似梯度 ∇qL~=∑j∈PτajDτ(g~⊤(vj−o~τ))kj\nabla_q \tilde{\mathcal{L}} = \sum_{j \in P_\tau} \frac{a_j}{D_\tau}(\tilde{g}^\top(v_j - \tilde{o}_\tau))k_jqL~=jPτDτaj(g~(vjo~τ))kj,然后相减并分解。∎

定理 7.3(各项的几乎必然界)

Rk=max⁡j∥kj∥2R_k = \max_j \|k_j\|_2Rk=maxjkj2。则存在常数 C1,C2,C3,C4C_1, C_2, C_3, C_4C1,C2,C3,C4 使得:

  1. ∥(I)∥2≤C1DUτD\|(I)\|_2 \leq C_1 \frac{D_{U_\tau}}{D}(I)2C1DDUτ
  2. ∥(II)∥2≤C2∥o~τ−o∗∥2\|(II)\|_2 \leq C_2 \|\tilde{o}_\tau - o^*\|_2(II)2C2o~τo2
  3. ∥(III)∥2≤C3∥o~τ−o∗∥2\|(III)\|_2 \leq C_3 \|\tilde{o}_\tau - o^*\|_2(III)2C3o~τo2
  4. ∥(IV)∥2≤C4DUτD\|(IV)\|_2 \leq C_4 \frac{D_{U_\tau}}{D}(IV)2C4DDUτ

其中 DUτ=∑j∈UτajD_{U_\tau} = \sum_{j \in U_\tau} a_jDUτ=jUτaj,且:

  • C1=G(Rv+∥o∗∥2)RkC_1 = G(R_v + \|o^*\|_2)R_kC1=G(Rv+o2)Rk
  • C2=Lg(Rv+∥o∗∥2)RkC_2 = L_g(R_v + \|o^*\|_2)R_kC2=Lg(Rv+o2)Rk
  • C3=GRkC_3 = G R_kC3=GRk
  • C4=G(Rv+∥o∗∥2)RkC_4 = G(R_v + \|o^*\|_2)R_kC4=G(Rv+o2)Rk

证明:对 (I)(I)(I)

∥(I)∥2≤∑j∈Pτaj∣1D−1Dτ∣∣g∗⊤(vj−o∗)∣∥kj∥2≤∑j∈PτajDUτDDτG(Rv+∥o∗∥2)Rk=DUτDG(Rv+∥o∗∥2)Rk\begin{aligned} \|(I)\|_2 &\leq \sum_{j \in P_\tau} a_j \left|\frac{1}{D} - \frac{1}{D_\tau}\right| |g^{*\top}(v_j - o^*)| \|k_j\|_2 \\ &\leq \sum_{j \in P_\tau} a_j \frac{D_{U_\tau}}{D D_\tau} G(R_v + \|o^*\|_2) R_k \\ &= \frac{D_{U_\tau}}{D} G(R_v + \|o^*\|_2) R_k \end{aligned}(I)2jPτaj D1Dτ1 g(vjo)∣∥kj2jPτajDDτDUτG(Rv+o2)Rk=DDUτG(Rv+o2)Rk

其他类似。∎

推论 7.4(总梯度误差)

∥∇qL−∇qL~∥2≤(C1+C4)DUτD+(C2+C3)∥o~τ−o∗∥2\|\nabla_q \mathcal{L} - \nabla_q \tilde{\mathcal{L}}\|_2 \leq (C_1 + C_4)\frac{D_{U_\tau}}{D} + (C_2 + C_3)\|\tilde{o}_\tau - o^*\|_2qLqL~2(C1+C4)DDUτ+(C2+C3)o~τo2

代入停止准则,得:

∥∇qL−∇qL~∥2≤((C1+C4)1D+(C2+C3)2RvDτ)∣Uτ∣exp⁡(ϵτ)\|\nabla_q \mathcal{L} - \nabla_q \tilde{\mathcal{L}}\|_2 \leq \left((C_1+C_4)\frac{1}{D} + (C_2+C_3)\frac{2R_v}{D_\tau}\right) |U_\tau|\exp(\epsilon_\tau)qLqL~2((C1+C4)D1+(C2+C3)Dτ2Rv)Uτexp(ϵτ)

第八章:树不平衡的复杂度分析

定义 8.1(平衡因子)

对二叉树 T\mathcal{T}T,定义平衡因子:

β(T)=min⁡内部节点 nmin⁡(∣In1∣,∣In2∣)max⁡(∣In1∣,∣In2∣)\beta(\mathcal{T}) = \min_{\text{内部节点 } n} \frac{\min(|I_{n_1}|, |I_{n_2}|)}{\max(|I_{n_1}|, |I_{n_2}|)}β(T)=内部节点 nminmax(In1,In2)min(In1,In2)

其中 n1,n2n_1, n_2n1,n2nnn 的子节点。

定理 8.2(队列大小的上界)

设树高为 hhh,平衡因子 β>0\beta > 0β>0。则算法过程中 ∣Qt∣≤log⁡Nlog⁡(1+β)=O(log⁡N)|Q_t| \leq \frac{\log N}{\log(1+\beta)} = O(\log N)Qtlog(1+β)logN=O(logN)

证明:队列中的节点对应未处理区域的划分。每次弹出最大 UnU_nUn 的节点,其对应的索引集大小至少为当前最大区域的 β\betaβ 倍。归纳可得队列大小受限于树的深度,而平衡树深度为 O(log⁡N)O(\log N)O(logN)。∎

定理 8.3(期望复杂度)

设每个查询处理的叶节点数为 MMM(随机变量)。则:

  1. E[M]≤Clog⁡(1/τ0)λmin⁡\mathbb{E}[M] \leq \frac{C \log(1/\tau_0)}{\lambda_{\min}}E[M]λminClog(1/τ0)
  2. E[时间]=O(E[M](d+dv+log⁡log⁡N))\mathbb{E}[\text{时间}] = O(\mathbb{E}[M](d + d_v + \log \log N))E[时间]=O(E[M](d+dv+loglogN))

其中 λmin⁡=min⁡nλmin⁡(Cov(k∣Rn))\lambda_{\min} = \min_n \lambda_{\min}(\text{Cov}(k|_{\mathcal{R}_n}))λmin=minnλmin(Cov(kRn))

证明:由大偏差理论,q⊤kjq^\top k_jqkj 的尾部衰减率由协方差矩阵的最小特征值控制。树的构造使高权重键集中在少数区域。∎

第九章:自适应参数估计的理论

算法 9.1(安全的 RvR_vRv 估计)

维护:

R^v(t)=max⁡s≤t∥vjs∥2\hat{R}_v^{(t)} = \max_{s \leq t} \|v_{j_s}\|_2R^v(t)=stmaxvjs2

其中 jsj_sjs 是第 sss 步处理的索引。

使用 R~v(t)=R^v(t)(1+ηt)\tilde{R}_v^{(t)} = \hat{R}_v^{(t)} (1 + \eta_t)R~v(t)=R^v(t)(1+ηt),其中 ηt=clog⁡tt\eta_t = c\sqrt{\frac{\log t}{t}}ηt=ctlogt

定理 9.2(估计的一致性)

假设 {vj}\{v_j\}{vj} 独立同分布,∥vj∥2≤Rv∗\|v_j\|_2 \leq R_v^*vj2Rv 几乎必然。则:

lim⁡t→∞R~v(t)=Rv∗a.s.\lim_{t \to \infty} \tilde{R}_v^{(t)} = R_v^* \quad \text{a.s.}tlimR~v(t)=Rva.s.

R~v(t)≥Rv∗\tilde{R}_v^{(t)} \geq R_v^*R~v(t)Rv 对足够大的 ttt 几乎必然成立。

证明:由强大数定律,R^v(t)→Rv∗\hat{R}_v^{(t)} \to R_v^*R^v(t)Rv。调节 ηt→0\eta_t \to 0ηt0,但收敛速度慢于 R^v(t)\hat{R}_v^{(t)}R^v(t),故最终 R~v(t)≥Rv∗\tilde{R}_v^{(t)} \geq R_v^*R~v(t)Rv。∎

定理 9.3(带估计的误差界)

使用 R~v(τ)\tilde{R}_v^{(\tau)}R~v(τ) 代替 RvR_vRv 在停止准则中,设实际停止时间为 τ~\tilde{\tau}τ~。则:

∥o~τ~−o∗∥2≤τ0R~v(τ~)Rv∗a.s.\|\tilde{o}_{\tilde{\tau}} - o^*\|_2 \leq \tau_0 \frac{\tilde{R}_v^{(\tilde{\tau})}}{R_v^*} \quad \text{a.s.}o~τ~o2τ0RvR~v(τ~)a.s.

特别地,如果 R~v(τ~)≤(1+ϵ)Rv∗\tilde{R}_v^{(\tilde{\tau})} \leq (1+\epsilon)R_v^*R~v(τ~)(1+ϵ)Rv,则误差 ≤τ0(1+ϵ)\leq \tau_0(1+\epsilon)τ0(1+ϵ)

第十章:数值稳定性证明

算法 10.1(对数域稳定算法)

  1. 初始化 m=−∞m = -\inftym=
  2. 处理节点时,对叶节点 jjj
    • 计算 sj=q⊤kjs_j = q^\top k_jsj=qkj
    • 更新 m=max⁡(m,sj)m = \max(m, s_j)m=max(m,sj)
    • 更新 D~=D~⋅em旧−m+esj−m\tilde{D} = \tilde{D} \cdot e^{m_{\text{旧}} - m} + e^{s_j - m}D~=D~emm+esjm
    • 更新 N~=N~⋅em旧−m+esj−mvj\tilde{N} = \tilde{N} \cdot e^{m_{\text{旧}} - m} + e^{s_j - m} v_jN~=N~emm+esjmvj
  3. 对于上界计算:ϵ′=ϵ−m\epsilon' = \epsilon - mϵ=ϵm

定理 10.2(数值稳定性)

上述算法满足:

  1. 所有指数参数 ∈[−B,0]\in [-B, 0][B,0],其中 B=max⁡i,j∣q⊤(ki−kj)∣≤2∥q∥2max⁡j∥kj∥2B = \max_{i,j} |q^\top(k_i - k_j)| \leq 2\|q\|_2 \max_j \|k_j\|_2B=maxi,jq(kikj)2∥q2maxjkj2
  2. 不会出现上溢或下溢(假设使用 IEEE 浮点数)
  3. 相对误差受机器精度 ϵmach\epsilon_{\text{mach}}ϵmach 控制:

∥计算值−精确值∥2≤CNϵmach∥精确值∥2\|\text{计算值} - \text{精确值}\|_2 \leq C N \epsilon_{\text{mach}} \|\text{精确值}\|_2计算值精确值2CNϵmach精确值2

证明:由构造,所有指数参数 sj−m≤0s_j - m \leq 0sjm0,故 esj−m≤1e^{s_j - m} \leq 1esjm1。累积误差分析采用标准浮点误差模型。∎

第十一章:混合策略的理论基础

定义 11.1(均匀性指标)

定义注意力均匀性:

ζ(q,K)=min⁡jajmax⁡jaj∈[0,1]\zeta(q, K) = \frac{\min_j a_j}{\max_j a_j} \in [0, 1]ζ(q,K)=maxjajminjaj[0,1]

ζ≈1\zeta \approx 1ζ1 表示均匀,ζ≈0\zeta \approx 0ζ0 表示稀疏。

定理 11.2(检测与切换)

存在阈值 θ\thetaθ 和检测窗口 WWW,使得:

  1. 如果 E[ζ]>θ\mathbb{E}[\zeta] > \thetaE[ζ]>θ,则分块计算更优
  2. 如果 E[ζ]<θ\mathbb{E}[\zeta] < \thetaE[ζ]<θ,则 HISDMA 更优
  3. 基于 WWW 个样本的估计 ζ^\hat{\zeta}ζ^ 以高概率正确分类

证明:比较两种算法的期望复杂度关于 ζ\zetaζ 的函数。∎

算法 11.3(自适应混合策略)

  1. 初始化:使用 HISDMA
  2. WWW 步计算 ζ^\hat{\zeta}ζ^
  3. 如果 ζ^>θ\hat{\zeta} > \thetaζ^>θ 持续 TTT 次,切换到分块计算
  4. 如果 ζ^<θ/2\hat{\zeta} < \theta/2ζ^<θ/2 持续 TTT 次,切换回 HISDMA

第十二章:实验验证的理论预测

定理 12.1(误差界的紧密度)

存在常数 c1,c2>0c_1, c_2 > 0c1,c2>0 使得对任意 τ0>0\tau_0 > 0τ0>0

c1τ0≤sup⁡q,K,VE[∥o~τ−o∗∥2]≤c2τ0c_1 \tau_0 \leq \sup_{q,K,V} \mathbb{E}[\|\tilde{o}_\tau - o^*\|_2] \leq c_2 \tau_0c1τ0q,K,VsupE[o~τo2]c2τ0

即误差界在阶的意义下是紧的。

证明:构造两个极端例子:一个使误差接近下界,一个使误差接近上界。∎

定理 12.2(超参数选择)

最优安全系数 γ\gammaγ 满足:

γ∗=arg⁡min⁡γE[时间]s.t.P(误差>τ0)≤δ\gamma^* = \arg\min_\gamma \mathbb{E}[\text{时间}] \quad \text{s.t.} \quad P(\text{误差} > \tau_0) \leq \deltaγ=argγminE[时间]s.t.P(误差>τ0)δ

渐近地,γ∗∼log⁡(1/δ)log⁡N\gamma^* \sim \sqrt{\frac{\log(1/\delta)}{\log N}}γlogNlog(1/δ)

第十三章:与现有工作的理论比较

定理 13.1(内存复杂度下界)

任何注意力算法如果要精确计算,必须使用 Ω(Nd)\Omega(Nd)Ω(Nd) 内存。HISDMA 使用 O(Nd+log⁡N)O(Nd + \log N)O(Nd+logN) 内存,是最优的(达到对数因子)。

定理 13.2(与 FlashAttention 比较)

设 FlashAttention 的分块大小为 MMM,则:

  1. FlashAttention 时间:O(N2d/M)O(N^2 d / M)O(N2d/M)
  2. HISDMA 期望时间:O(Nlog⁡N(d+log⁡log⁡N)/λmin⁡)O(N \log N (d + \log \log N) / \lambda_{\min})O(NlogN(d+loglogN)/λmin)

λmin⁡\lambda_{\min}λmin 小(稀疏)时,HISDMA 显著更快。

第十四章:结论与开放问题

14.1 主要理论贡献

  1. 建立了 HISDMA 的完全严谨的概率论分析框架
  2. 给出了误差、梯度误差、复杂度的有限样本和高概率界
  3. 设计了自适应参数估计和数值稳定算法
  4. 证明了算法的最优性和紧密度

14.2 开放理论问题

  1. 非独立同分布键值序列的分析
  2. 在线学习中的分布漂移
  3. 多查询联合优化(注意力矩阵而非向量)
  4. 低精度计算(如 FP16)的误差分析

14.3 实践建议

  1. 监控实际误差与理论界的比值,调整安全系数
  2. 定期重构聚类树以适应数据分布变化
  3. 对于超长序列(>108>10^8>108),使用外存版 HISDMA
  4. 结合模型并行,分布聚类树到多个设备
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐