求问我的递归写法会很慢吗
查看原帖
求问我的递归写法会很慢吗
998707
spire001（楼主） 2024/9/25 16:16

本人看不懂非递归写法，于是自己写出了下面这个递归版本。

#include <algorithm>
#include <cassert>
#include <iostream>

// Runs one complete convolution with transform family `x` (or / and / xor):
//   1. init():  copy the saved inputs _a/_b into the working arrays a/b,
//   2. forward-transform a and b with template argument `sig`,
//   3. mul():   point-wise product into a,
//   4. inverse-transform a with template argument `inv`,
//   5. print(): write the result.
// Wrapped in do { ... } while (0) so that `solve(...);` expands to exactly one
// statement (safe inside an unbraced if/else).
// Note: `or`, `and`, `xor` are C++ alternative operator tokens; pasting them
// onto `FWT_` still forms the identifiers FWT_or / FWT_and / FWT_xor.
#define solve(x, sig, inv)                                                     \
  do {                                                                         \
    init();                                                                    \
    FWT_##x<sig>(n, a);                                                        \
    FWT_##x<sig>(n, b);                                                        \
    mul();                                                                     \
    FWT_##x<inv>(n, a);                                                        \
    print();                                                                   \
  } while (0)

using namespace std;

// 64-bit signed integer used for every coefficient and intermediate product.
using LL = long long;

// Capacity of each global coefficient array; 2^k from the input must fit.
constexpr int N{500010};
// Modulus applied to all arithmetic (998244353 is prime).
constexpr int mod{998244353};

// In-place OR transform of f[0, n) modulo `mod`.
// With sign = 1, f[i] becomes the sum of f[j] over all j whose set bits are a
// subset of i's. With sign = -1 the same pass undoes it (the inverse transform).
// Assuming inputs lie in (-mod, mod), every result also lies in (-mod, mod);
// callers normalise to [0, mod) when printing.
// @param n  length of f; expected to be a power of two (n <= 1 is a no-op).
// @param f  array of at least n values, transformed in place.
template <int sign> void FWT_or(int n, LL *f) {
  // Guard on n <= 1 rather than n == 1: with n == 0, m would also be 0 and the
  // recursion below would never terminate.
  if (n <= 1)
    return;
  const int m = n >> 1;
  FWT_or<sign>(m, f);     // low half: indices with the top bit clear
  FWT_or<sign>(m, f + m); // high half: indices with the top bit set
  // Each high-half index has its low-half counterpart (i - m) as a bit-subset,
  // so fold that contribution in.
  for (int i = m; i != n; i++)
    (f[i] += sign * f[i - m]) %= mod;
}
// In-place AND transform of f[0, n) modulo `mod`.
// With sign = 1, f[i] becomes the sum of f[j] over all j whose set bits are a
// superset of i's. With sign = -1 the same pass undoes it (the inverse transform).
// Assuming inputs lie in (-mod, mod), every result also lies in (-mod, mod);
// callers normalise to [0, mod) when printing.
// @param n  length of f; expected to be a power of two (n <= 1 is a no-op).
// @param f  array of at least n values, transformed in place.
template <int sign> void FWT_and(int n, LL *f) {
  // Guard on n <= 1 rather than n == 1: with n == 0, m would also be 0 and the
  // recursion below would never terminate.
  if (n <= 1)
    return;
  const int m = n >> 1;
  FWT_and<sign>(m, f);     // low half: indices with the top bit clear
  FWT_and<sign>(m, f + m); // high half: indices with the top bit set
  // Each low-half index has its high-half counterpart (i + m) as a bit-superset,
  // so fold that contribution in.
  for (int i = 0; i != m; i++)
    (f[i] += sign * f[i + m]) %= mod;
}
template <int sign> void FWT_xor(int n, LL *f) {
  if (n == 1)
    return;
  const int m = n >> 1;
  FWT_xor<sign>(m, f);
  FWT_xor<sign>(m, f + m);
  for (int i = 0; i != m; i++) {
    LL tmp = f[i];
    f[i] = (f[i] + f[i + m]) * sign % mod;
    f[i + m] = (tmp - f[i + m]) * sign % mod;
  }
  return;
}

// Working length of every array; main() sets it to 2^k from the input.
int n;
// Saved input sequences A and B, re-copied into a/b before each convolution.
// NOTE(review): identifiers starting with '_' are reserved in the global
// namespace; consider renaming _a/_b (requires updating init() and main()).
LL _a[N];
LL _b[N];
// Scratch buffers that the FWT routines transform in place.
LL a[N];
LL b[N];

// Restore the working buffers from the saved inputs: a <- _a, b <- _b,
// for the first n entries.
void init() {
  for (int idx = 0; idx < n; ++idx) {
    a[idx] = _a[idx];
    b[idx] = _b[idx];
  }
}
void mul() {
  for (int i = 0; i != n; i++)
    (a[i] *= b[i]) %= mod;
}
// Write the first n entries of a, each normalised into [0, mod) (the
// transforms may leave negative residues). Values are space-separated with a
// trailing space, and the line ends with a newline.
void print() {
  for (int i = 0; i != n; i++)
    cout << (a[i] % mod + mod) % mod << ' ';
  // '\n' instead of endl: flushing after every line is unnecessary because
  // cout is flushed when the program terminates normally.
  cout << '\n';
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);

  cin >> n;
  n = 1 << n;

  for (int i = 0; i != n; i++)
    cin >> _a[i];
  for (int i = 0; i != n; i++)
    cin >> _b[i];

  solve(or, 1, -1);
  solve(and, 1, -1);
  solve(xor, 1, 499122177);

  return 0;
}
2024/9/25 16:16
加载中...