目录

P4512-【模板】多项式除法

P4512-【模板】多项式除法

题目:

题目描述:

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$ ,请求出多项式 $Q(x)$, $R(x)$,满足以下条件:

  • $Q(x)$ 次数为 $n-m$,$R(x)$ 次数小于 $m$
  • $F(x) = Q(x) * G(x) + R(x)$

所有的运算在模 $998244353$ 意义下进行。

输入格式:

第一行两个整数 $n$,$m$,意义如上。
第二行 $n+1$ 个整数,从低到高表示 $F(x)$ 的各个系数。
第三行 $m+1$ 个整数,从低到高表示 $G(x)$ 的各个系数。

输出格式:

第一行 $n-m+1$ 个整数,从低到高表示 $Q(x)$ 的各个系数。
第二行 $m$ 个整数,从低到高表示 $R(x)$ 的各个系数。
如果 $R(x)$ 不足 $m-1$ 次,多余的项系数补 $0$。

样例:

样例输入 1:

1
2
3
5 1
1 9 2 6 0 8
1 7

样例输出 1:

1
2
237340659 335104102 649004347 448191342 855638018
760903695

思路:

实现:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include "ybwhead/ios.h"
int n, lim, l;
const int maxn = 1e6 + 10;
long long a[maxn], b[maxn];
int r[maxn];
const int mod = 998244353;
long long c[maxn];
long long ksm(long long a, int n)
{
    long long ans = 1;
    while (n)
    {
        if (n & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return ans;
}
const int g = 3, gi = 332748118;
void ntt(long long a[], long long x)
{
    for (int i = 0; i < lim; i++)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    for (int o = 2, k = 1; o <= lim; o <<= 1, k <<= 1)
    {
        long long wn = ksm(x, (mod - 1) / (o));
        for (int i = 0; i < lim; i += o)
        {
            long long w = 1;
            for (int j = 0; j < k; j++, w = w * wn % mod)
            {
                long long x = a[i + j], y = a[i + j + k] * w % mod;
                a[j + i] = (x + y) % mod;
                a[j + k + i] = (x - y + mod) % mod;
            }
        }
    }
}
void calc(int n, long long b[])
{
    if (n == 1)
    {
        b[0] = ksm(a[0], mod - 2);
        // yout<<a[0]
        return;
    }
    calc((n + 1) >> 1, b);
    lim = 1;
    l = 0;
    while (lim < (n << 1))
        lim <<= 1, ++l;
    // yout<<lim<<" "<<a[1]<<endl;
    for (int i = 0; i < lim; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for (int i = 0; i < n; i++)
        c[i] = a[i];
    for (int i = n; i < lim; i++)
        c[i] = 0;
    ntt(c, g);
    ntt(b, g);
    for (int i = 0; i < lim; i++)
    {
        b[i] = (2 - c[i] * b[i] % mod + mod) % mod * b[i] % mod;
        // yout<<b[i]<<endl;
    }
    ntt(b, gi);
    long long inv = ksm(lim, mod - 2);
    for (int i = 0; i < n; i++)
        b[i] = (b[i] * inv) % mod;
    for (int i = n; i < lim; i++)
        b[i] = 0;
}
int m;
long long d[maxn];
long long aa[maxn];
int main()
{
    yin >> m >> n;
    ++m;
    ++n;
    for (int i = 0; i < m; i++)
        yin >> d[i];
    for (int i = 0; i < n; i++)
        yin >> a[i], aa[i] = a[i];
    reverse(a, a + n);
    for (int i = m - n + 1; i < m; i++)
        a[i] = 0;
    calc(m - n + 1, b);
    lim = 1, l = 0;
    while (lim <= (m << 1))
        lim <<= 1, ++l;
    reverse(d, d + m);
    for (int i = 0; i < lim; i++)
        c[i] = d[i];
    reverse(d, d + m);
    for (int i = 0; i < lim; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    ntt(c, g);
    ntt(b, g);
    for (int i = 0; i < lim; i++)
        c[i] *= b[i], c[i] %= mod;
    ntt(c, gi);
    reverse(c, c + m - n + 1);
    long long inv = ksm(lim, mod - 2);
    for (int i = 0; i < lim; i++)
        c[i] = (c[i] * inv) % mod;
    for (int i = 0; i < m - n + 1; i++)
        yout << c[i] << " ";
    for (int i = m - n + 1; i < lim; i++)
        c[i] = 0;
    yout << endl;
    for (int i = 0; i < m; i++)
        a[i] = aa[i];
    ntt(a, g);
    ntt(c, g);
    for (int i = 0; i < lim; i++)
        a[i] *= c[i], a[i] %= mod;
    ntt(a, gi);
    inv = ksm(lim, mod - 2);
    for (int i = 0; i < n - 1; i++)
    {
        yout << (d[i] - a[i] * inv % mod + mod) % mod << ' ';
    }
    return 0;
}