目录

P4717-【模板】快速沃尔什变换 (FWT)

P4717-【模板】快速沃尔什变换 (FWT)

题目:

题目描述:

给定长度为 $2^n$ 两个序列 $A, B$,设

$$C_i=\sum_{j\oplus k = i}A_j \times B_k$$

分别当 $\oplus$ 是 or, and, xor 时求出 $C$

输入格式:

第一行一个数 n。 第二行$2^n$个数$A_0.. A_{2^n-1}$ 第三行$2^n$个数$B_0.. B_{2^n-1}$

输出格式:

三行每行$2^n$个数,分别代表$\oplus$是 or, and, xor 时$C_0.. C_{2^n-1}$的值$\bmod\ 998244353$

样例:

样例输入 1:

1
2
3
2
2 4 6 8
1 3 5 7

样例输出 1:

1
2
3
2 22 46 250
88 64 112 56
100 92 68 60

思路:

实现:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
#include "ybwhead/ios.h"
const int maxn = (1 << 17) + 10;
int n;
long long a[maxn], b[maxn], f[maxn], g[maxn];
const int mod = 998244353;
const int inv = (mod + 1) >> 1;
void mul()
{
    for (int i = 1; i <= n; i++)
        f[i] *= g[i], f[i] %= mod;
}
void print()
{
    for (int i = 1; i <= n; i++)
        yout << f[i] << " ";
    puts("");
}
void pre()
{
    memcpy(f, a, sizeof(f));
    memcpy(g, b, sizeof(g));
}
void __AND(long long f[], int x = 1)
{
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
    {
        for (int i = 1; i <= n; i += o)
        {
            for (int j = 0; j < k; j++)
            {
                f[i + j] += f[i + j + k] * x + mod;
                while (f[i + j] >= mod)
                    f[i + j] -= mod;
            }
        }
    }
}
void __OR(long long f[], int x = 1)
{
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
    {
        for (int i = 1; i <= n; i += o)
        {
            for (int j = 0; j < k; j++)
            {
                f[i + j + k] += f[i + j] * x + mod;
                while (f[i + j + k] >= mod)
                    f[i + j + k] -= mod;
            }
        }
    }
}
void __XOR(long long f[], int x = 1)
{
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
    {
        for (int i = 1; i <= n; i += o)
        {
            for (int j = 0; j < k; j++)
            {
                f[i + j] += f[i + j + k];
                f[i + j + k] = f[i + j] - f[i + j + k] - f[i + j + k];
                f[i + j] *= x;
                f[i + j + k] *= x;
                f[i + j] = (f[i + j] % mod + mod) % mod;
                f[i + j + k] = (f[i + j + k] % mod + mod) % mod;
            }
        }
    }
}

int main()
{
    yin >> n;
    n = 1 << n;
    for (int i = 1; i <= n; i++)
        yin >> a[i];
    for (int i = 1; i <= n; i++)
        yin >> b[i];
    pre();
    __OR(f);
    __OR(g);
    mul();
    __OR(f, -1);
    print();

    pre();
    __AND(f);
    __AND(g);
    mul();
    __AND(f, -1);
    print();

    pre();
    __XOR(f);
    __XOR(g);
    mul();
    __XOR(f, inv);
    print();

    return 0;
}