多项式快速插值完整详解

一、问题与拉格朗日插值公式

给定 nnn 个点 (xi,yi)(x_i, y_i)(xi,yi),求 n−1n-1n1 次多项式 f(x)f(x)f(x) 满足 f(xi)=yif(x_i) = y_if(xi)=yi

拉格朗日插值公式:
f(x)=∑i=1nyi∏j≠ix−xjxi−xjf(x) = \sum_{i=1}^n y_i \prod_{j \neq i} \frac{x - x_j}{x_i - x_j}f(x)=i=1nyij=ixixjxxj

直接计算是 O(n2)O(n^2)O(n2) 的,需要优化。

二、快速插值算法框架

2.1 基本思路

设:

  • M(x)=∏i=1n(x−xi)M(x) = \prod_{i=1}^n (x - x_i)M(x)=i=1n(xxi)
  • Mi(x)=M(x)x−xi=∏j≠i(x−xj)M_i(x) = \frac{M(x)}{x - x_i} = \prod_{j \neq i} (x - x_j)Mi(x)=xxiM(x)=j=i(xxj)
  • wi=yi∏j≠i(xi−xj)=yiM′(xi)w_i = \frac{y_i}{\prod_{j \neq i} (x_i - x_j)} = \frac{y_i}{M'(x_i)}wi=j=i(xixj)yi=M(xi)yi

则:
f(x)=∑i=1nwiMi(x)f(x) = \sum_{i=1}^n w_i M_i(x)f(x)=i=1nwiMi(x)

算法步骤:

  1. 计算 M(x)=∏i=1n(x−xi)M(x) = \prod_{i=1}^n (x - x_i)M(x)=i=1n(xxi)
  2. 计算 M′(x)M'(x)M(x)
  3. 多点求值计算 M′(xi)M'(x_i)M(xi)
  4. 计算 wi=yi⋅[M′(xi)]−1mod  998244353w_i = y_i \cdot [M'(x_i)]^{-1} \mod 998244353wi=yi[M(xi)]1mod998244353
  5. 计算 f(x)=∑wiMi(x)f(x) = \sum w_i M_i(x)f(x)=wiMi(x)

步骤5用分治:
fl,r(x)=∑i=lrwi∏j=l,j≠ir(x−xj)f_{l,r}(x) = \sum_{i=l}^r w_i \prod_{j=l, j\neq i}^r (x - x_j)fl,r(x)=i=lrwij=l,j=ir(xxj)
则:
fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,midf_{l,r} = f_{l,mid} \cdot M_{mid+1,r} + f_{mid+1,r} \cdot M_{l,mid}fl,r=fl,midMmid+1,r+fmid+1,rMl,mid
其中 Ml,r=∏i=lr(x−xi)M_{l,r} = \prod_{i=l}^r (x - x_i)Ml,r=i=lr(xxi)

三、基础数学知识

3.1 模运算与逆元

模数 p=998244353p = 998244353p=998244353 是质数,费马小定理:ap−1≡1(modp)a^{p-1} \equiv 1 \pmod{p}ap11(modp)
逆元:a−1≡ap−2(modp)a^{-1} \equiv a^{p-2} \pmod{p}a1ap2(modp)

代码中的快速幂:

inline long long pow2(long long a1, long long b1) {
    long long c1 = 1;
    while (b1 != 0) {
        if (b1 % 2 == 1) c1 = c1 * a1 % mod;
        a1 = a1 * a1 % mod;
        b1 /= 2;
    }
    return c1;
}

3.2 Barrett Reduction

普通取模用 % 较慢,Barrett Reduction 用乘法和移位代替除法:
a mod p=a−⌊a/p⌋×pa \bmod p = a - \lfloor a/p \rfloor \times pamodp=aa/p×p
近似计算 ⌊a/p⌋\lfloor a/p \rfloora/p
⌊a/p⌋≈⌊(a×m)/264⌋,m=⌊264/p⌋\lfloor a/p \rfloor \approx \lfloor (a \times m) / 2^{64} \rfloor, \quad m = \lfloor 2^{64}/p \rfloora/p⌊(a×m)/264,m=264/p

代码实现:

const long long rr = (__int128)((__int128)1 << 64) / mod;
inline int mo(long long a1) {
    a1 -= ((__int128)a1 * rr >> 64) * mod;
    return a1 >= mod ? a1 - mod : a1;
}

四、快速数论变换(NTT)

4.1 原根与单位根

p=998244353=119×223+1p = 998244353 = 119 \times 2^{23} + 1p=998244353=119×223+1,原根 g=3g = 3g=3
nnn 次单位根:ωn=g(p−1)/n\omega_n = g^{(p-1)/n}ωn=g(p1)/n

NTT 变换:
A(k)=∑j=0n−1ajωnjk(k=0,1,…,n−1)A(k) = \sum_{j=0}^{n-1} a_j \omega_n^{jk} \quad (k=0,1,\dots,n-1)A(k)=j=0n1ajωnjk(k=0,1,,n1)

逆变换:
aj=1n∑k=0n−1A(k)ωn−jka_j = \frac{1}{n} \sum_{k=0}^{n-1} A(k) \omega_n^{-jk}aj=n1k=0n1A(k)ωnjk

4.2 递归NTT实现

代码使用递归实现,基于 Cooley-Tukey 算法:

inline void ntt(long long n1, long long *a1, long long b1) {
    if (n1 <= 1) return;
    // 奇偶分裂
    long long ax[n1/2], ay[n1/2];
    for (int i = 0; i <= n1-1; i += 2) {
        ax[i/2] = a1[i];
        ay[i/2] = a1[i+1];
    }
    ntt(n1/2, ax, b1);
    ntt(n1/2, ay, b1);
    
    // 合并
    long long ww, w = 1;
    if (b1 == 1) ww = pow2(3ll, (mod-1) / n1);  // 正变换
    else ww = pow2(pow2(3ll, (mod-1) / n1), mod-2);  // 逆变换
    
    for (int i = 0; i <= n1/2-1; i++, w = w * ww % mod) {
        a1[i] = (ax[i] + w * ay[i]) % mod;
        a1[i + n1/2] = (ax[i] - w * ay[i] + moc) % mod;
    }
}

其中 moc = 1ll*mod*mod 用于防止负数。

4.3 优化:预处理旋转因子

uv[n1][0] 存储 ωn1\omega_{n1}ωn1uv[n1][1] 存储 ωn1−1\omega_{n1}^{-1}ωn11,避免重复计算。

五、多项式基本操作

5.1 多项式乘法

NTT 将多项式乘法 O(n2)O(n^2)O(n2) 降为 O(nlog⁡n)O(n\log n)O(nlogn)

  1. 将多项式系数补零到 2k2^k2k 长度
  2. 分别 NTT
  3. 点值相乘
  4. 逆 NTT

5.2 多项式求逆

给定 A(x)A(x)A(x),求 B(x)B(x)B(x) 使 A(x)B(x)≡1(modxn)A(x)B(x) \equiv 1 \pmod{x^n}A(x)B(x)1(modxn)

牛顿迭代法:已知 B0(x)B_0(x)B0(x) 满足 A(x)B0(x)≡1(modx⌈n/2⌉)A(x)B_0(x) \equiv 1 \pmod{x^{\lceil n/2 \rceil}}A(x)B0(x)1(modxn/2)
则:
B(x)≡2B0(x)−A(x)B02(x)(modxn)B(x) \equiv 2B_0(x) - A(x)B_0^2(x) \pmod{x^n}B(x)2B0(x)A(x)B02(x)(modxn)

代码在 ntt2 中实现求逆:

a[0] = pow2(g[0], mod-2);  // 初始逆元
while (nn <= n-m+1) {
    nn *= 2;
    nn2 = pow2(nn*2ll % mod, mod-2);
    // 计算 A(x)B_0(x)
    ntt(nn*2, a, 1);
    ntt(nn*2, b, 1);
    for (int i = 0; i <= nn*2-1; i++) {
        a[i] = a[i] * (2ll - b[i] * a[i] % mod + mod) % mod;
    }
    ntt(nn*2, a, -1);
    for (int i = 0; i <= nn*2-1; i++) {
        a[i] = a[i] * nn2 % mod;
        if (i >= nn) a[i] = 0;
    }
}

5.3 多项式求导

A(x)=∑i=0naixiA(x) = \sum_{i=0}^n a_i x^iA(x)=i=0naixi,则 A′(x)=∑i=1niaixi−1A'(x) = \sum_{i=1}^n i a_i x^{i-1}A(x)=i=1niaixi1

代码中:

for (int i = 1; i <= nn; i++) {
    ff[i-1] = v[1][i] * i % mod;  // M'(x) 系数
}

5.4 多项式取模

给定 F(x)F(x)F(x)G(x)G(x)G(x),求 Q(x)Q(x)Q(x)R(x)R(x)R(x) 使:
F(x)=Q(x)G(x)+R(x),deg⁡R<deg⁡GF(x) = Q(x)G(x) + R(x), \quad \deg R < \deg GF(x)=Q(x)G(x)+R(x),degR<degG

算法:

  1. n=deg⁡Fn = \deg Fn=degF, m=deg⁡Gm = \deg Gm=degG
  2. 反转:FR(x)=xnF(1/x)F_R(x) = x^n F(1/x)FR(x)=xnF(1/x), GR(x)=xmG(1/x)G_R(x) = x^m G(1/x)GR(x)=xmG(1/x)
  3. GR(x)G_R(x)GR(x)xn−m+1x^{n-m+1}xnm+1 的逆 H(x)H(x)H(x)
  4. QR(x)=FR(x)H(x) mod xn−m+1Q_R(x) = F_R(x)H(x) \bmod x^{n-m+1}QR(x)=FR(x)H(x)modxnm+1
  5. 反转得 Q(x)Q(x)Q(x)
  6. R(x)=F(x)−Q(x)G(x)R(x) = F(x) - Q(x)G(x)R(x)=F(x)Q(x)G(x)

代码在 ntt2 中实现,用于多点求值。

六、分治构建多项式乘积

6.1 构建 Ml,r(x)=∏i=lr(x−xi)M_{l,r}(x) = \prod_{i=l}^r (x - x_i)Ml,r(x)=i=lr(xxi)

ntt0 函数:

inline void ntt0(long long n1, long long *a1, long long b1, long long l, long long r) {
    if (l == r) {  // 叶子节点
        a1[0] = (-xx[l] + mod) % mod;  // -x_i
        a1[1] = 1;                     // 系数1
        v[b1].push_back((-xx[l] + mod) % mod);
        v[b1].push_back(1);
        return;
    }
    long long mid = (l+r)/2;
    long long ax[n1*2+1], ay[n1*2+1];
    ntt0(n1/2, ax, b1*2, l, mid);
    ntt0(n1/2, ay, b1*2+1, mid+1, r);
    
    // 合并:多项式乘法
    ntt(n1*2, ax, 1);
    ntt(n1*2, ay, 1);
    for (int i = 0; i <= n1*2-1; i++) ax[i] = ax[i] * ay[i] % mod;
    ntt(n1*2, ax, -1);
    long long n12 = pow2(n1*2ll % mod, mod-2);
    for (int i = 0; i <= n1*2-1; i++) {
        ax[i] = ax[i] * n12 % mod;
        v[b1].push_back(ax[i]);  // 存储到vector
        a1[i] = ax[i];
    }
}

结果存在 v[b1] 中,b1 是节点编号。

七、多点求值

7.1 原理

计算 F(xi)F(x_i)F(xi) 对每个 xix_ixi
F(x) mod (x−xi)=F(xi)F(x) \bmod (x-x_i) = F(x_i)F(x)mod(xxi)=F(xi)

对区间 [l,r][l,r][l,r],设 Gl,r(x)=∏i=lr(x−xi)G_{l,r}(x) = \prod_{i=l}^r (x - x_i)Gl,r(x)=i=lr(xxi)
计算 R(x)=F(x) mod Gl,r(x)R(x) = F(x) \bmod G_{l,r}(x)R(x)=F(x)modGl,r(x),则:
∀i∈[l,r],F(xi)=R(xi)\forall i \in [l,r], \quad F(x_i) = R(x_i)i[l,r],F(xi)=R(xi)
因为 F(x)=Q(x)Gl,r(x)+R(x)F(x) = Q(x)G_{l,r}(x) + R(x)F(x)=Q(x)Gl,r(x)+R(x),而 Gl,r(xi)=0G_{l,r}(x_i)=0Gl,r(xi)=0

7.2 实现

ntt2 函数:

inline void ntt2(long long n1, long long *a1, long long b1, long long l, long long r) {
    if (l + 2000 >= r) {  // 小范围暴力
        for (int i = l; i <= r; i++) ui[i-l+1] = 1;
        for (int i = 0; i <= n1/2; i++) {
            for (int u = l; u <= r; u++) {
                yy[u] = (yy[u] + ui[u-l+1] * a1[i]);
                ui[u-l+1] = ui[u-l+1] * xx[u] % mod;
            }
            if (i % 6 == 0) {  // 每6次取模一次,优化
                for (int u = l; u <= r; u++) yy[u] %= mod;
            }
        }
        for (int u = l; u <= r; u++) yy[u] %= mod;
        return;
    }
    // 大范围:分治取模
    long long mid = (l+r)/2;
    // 取左子树多项式
    for (int i = 0; i < v[b1*2].size(); i++) g[i] = v[b1*2][i];
    // 计算 a1 mod g
    // ... 多项式取模代码
    ntt2(n1/2, cff, b1*2, l, mid);    // 递归左子树
    ntt2(n1/2, cff2, b1*2+1, mid+1, r); // 递归右子树
}

八、快速插值合并

8.1 原理

已知 wi=yi/M′(xi)w_i = y_i / M'(x_i)wi=yi/M(xi),求:
fl,r(x)=∑i=lrwi∏j=l,j≠ir(x−xj)f_{l,r}(x) = \sum_{i=l}^r w_i \prod_{j=l, j\neq i}^r (x - x_j)fl,r(x)=i=lrwij=l,j=ir(xxj)

递推关系:
fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,midf_{l,r} = f_{l,mid} \cdot M_{mid+1,r} + f_{mid+1,r} \cdot M_{l,mid}fl,r=fl,midMmid+1,r+fmid+1,rMl,mid

8.2 实现

ntt3 函数:

inline void ntt3(long long n1, long long *f1, long long *m1, long long l, long long r) {
    if (l == r) {  // 叶子
        f1[0] = uq[l];           // w_l
        m1[0] = (mod - xq[l]) % mod;  // -x_l
        m1[1] = 1;               // 系数1
        return;
    }
    long long mid = (l+r)/2;
    long long fx[n1*2+1], fy[n1*2+1], mx[n1*2+1], my[n1*2+1];
    ntt3(n1/2, fx, mx, l, mid);
    ntt3(n1/2, fy, my, mid+1, r);
    
    // f = fx*my + fy*mx
    // m = mx*my
    long long nn = n1*2;
    ntt(nn, fx, 1); ntt(nn, fy, 1);
    ntt(nn, mx, 1); ntt(nn, my, 1);
    long long nn2 = pow2(nn, mod-2);
    
    for (int i = 0; i <= nn; i++) f1[i] = (fx[i] * my[i]) % mod;
    ntt(nn, f1, -1);
    for (int i = 0; i <= nn; i++) f1[i] = f1[i] * nn2 % mod;
    
    long long fg[nn+1];
    for (int i = 0; i <= nn; i++) fg[i] = (fy[i] * mx[i]) % mod;
    ntt(nn, fg, -1);
    for (int i = 0; i <= nn; i++) fg[i] = fg[i] * nn2 % mod;
    
    for (int i = 0; i <= nn; i++) f1[i] = (f1[i] + fg[i]) % mod;
    
    for (int i = 0; i <= nn; i++) m1[i] = mx[i] * my[i] % mod;
    ntt(nn, m1, -1);
    for (int i = 0; i <= nn; i++) m1[i] = m1[i] * nn2 % mod;
}

九、主函数流程

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        scanf("%lld%lld", &xq[i], &yq[i]);
        xx[i] = xq[i];
    }
    
    // 1. 构造 M(x) = ∏(x - x_i)
    ntt0(nn/2, qwerty, 1, 1, m);
    
    // 2. 求导得到 M'(x)
    for (int i = 1; i <= nn; i++) {
        ff[i-1] = v[1][i] * i % mod;
    }
    
    // 3. 多点求值计算 M'(x_i)
    ntt2(nn, ff, 1, 1, m);
    
    // 4. 计算权重 w_i = y_i / M'(x_i)
    for (int i = 1; i <= m; i++) {
        uq[i] = yq[i] * pow2(yy[i], mod-2) % mod;
    }
    
    // 5. 分治合并得到最终多项式
    ntt3(nn/2, fe, me, 1, m);
    
    // 输出结果
    for (int i = 0; i <= n-1; i++) {
        cout << fe[i] << " ";
    }
    return 0;
}

十、关键细节

10.1 数组大小

  • 多项式乘法后次数为两多项式次数之和
  • 递归深度为 log⁡n\log nlogn
  • 原代码开 600005600005600005 大小,因为 n≤100000n \leq 100000n100000,需要 2⌈log⁡2n⌉×22^{\lceil \log_2 n \rceil} \times 22log2n×2 以上

10.2 优化技巧

  1. 小范围暴力:区间较小时直接霍纳法求值
  2. 记忆化旋转因子uv 数组存储 ωn\omega_nωnωn−1\omega_n^{-1}ωn1
  3. Barrett Reduction:加速取模
  4. 延迟取模:累加多次后再取模(i%6==0

10.3 注意事项

  1. 所有运算在模 998244353998244353998244353 下进行
  2. 逆元不存在时(M′(xi)=0M'(x_i)=0M(xi)=0)不会发生,因为 xix_ixi 互不相同
  3. 多项式次数不足 n−1n-1n1 时高位补零
  4. 输出系数从低次到高次

十一、复杂度分析

  • 构建 M(x)M(x)M(x)O(nlog⁡2n)O(n \log^2 n)O(nlog2n)
  • 求导:O(n)O(n)O(n)
  • 多点求值:O(nlog⁡2n)O(n \log^2 n)O(nlog2n)
  • 计算权重:O(nlog⁡mod)O(n \log mod)O(nlogmod)
  • 插值合并:O(nlog⁡2n)O(n \log^2 n)O(nlog2n)
    总复杂度:O(nlog⁡2n)O(n \log^2 n)O(nlog2n)

十二、总结

快速插值算法结合了:

  1. 分治思想
  2. NTT 加速多项式乘法
  3. 多项式求逆
  4. 多项式取模
  5. 多点求值技巧

每个部分都需要精心实现,注意边界条件和优化细节。通过分治,将 O(n2)O(n^2)O(n2) 的拉格朗日插值优化到 O(nlog⁡2n)O(n \log^2 n)O(nlog2n),是多项式高级算法的经典应用。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐