This documentation is automatically generated by online-judge-tools/verification-helper
#include "Convolution/FastFourierTransform.hpp"
$2$ つの実数係数多項式 $f(x), g(x)$ に対して、積 $f(x)g(x)$ を計算します。
計算量は $\Theta(N\log{N})$ (積の次数を $N$ ) です。
基数 $4$ の時間間引き Cooley-Tukey 型アルゴリズムを用いています。
計算を行う多項式積の最大次数 $N$ が分かっている場合、max_n
に $N + 1$ を指定します。
事前に計算テーブルを構築することにより少しだけ高速になります。
実装の都合上、事前計算テーブルを使用した積の計算には ()
演算子を使用して下さい。
制約
T
は int
, long long int
, float
, double
計算量
静的メンバ関数 multiply(...)
と仕様は同じです。
事前計算テーブルを使用することによって少しだけ高速になります。
制約
T
は int
, long long int
, float
, double
計算量
$f(x) := \sum_{i=0}^{|a| - 1} a[i] x^i$, $g(x) := \sum_{i=0}^{|b| - 1} b[i] x^i$ として、
積 $f(x)g(x) = \sum_{i=0}^{|a| + |b| - 2} c[i] x^i$ となるような大きさ $|a| + |b| - 1$ の列 $c[i]$ を double
型で返します。
$a, b$ いずれかが空である場合は、空の列を返します。
制約
T
は int
, long long int
, float
, double
計算量
#include <bits/stdc++.h>
#include "Convolution/FastFourierTransform.hpp"
using namespace std;
int main() {
vector<int> a({0, 1, 2, 3}), b({2, 3, 4});
auto c = FastFourierTransform::multiply(a, b);
cout << "size = " << c.size() << endl; // 6
for (int i = 0; i < c.size(); ++i) cout << c[i] << " "; cout << endl; // 0 2 7 16 17 12
// 0 0 0
// 2 3 4
// 4 6 8
// 6 9 12
// ==============
// 0 2 7 16 17 12
// c は double 型であることに注意
vector<int> d{};
cout << "size = " << FastFourierTransform::multiply(a, d).size() << endl; // 0
vector<double> ad({1.5, 2}), bd({2, 5.5, 4});
auto cd = FastFourierTransform::multiply(ad, bd);
cout << "size = " << cd.size() << endl; // 4
for (int i = 0; i < cd.size(); ++i) cout << cd[i] << " "; cout << endl; // 3 12.25 17 8
// 3 8.25 6
// 4 11 8
// ==============
// 3 12.25 17 8
auto fft = FastFourierTransform(6); // zeta 配列を使い回す
auto pc = fft(a, b);
auto pcd = fft(ad, bd);
cout << "c == pc : " << boolalpha << (c == pc) << endl; // true
cout << "cd == pcd : " << boolalpha << (cd == pcd) << endl; // true
}
離散フーリエ変換の基礎については既知であるとします。
多項式 $f(x)$ の $i$ 次の係数を $f[k]$ と書きます。
つまり、$f(x)$ の次数を $d$ とすると、 $f(x) = \sum_{i=0}^d f[i]x^i$ です。
$\zeta_N$ は $1$ の $N$ 乗根を表します。
また、複素数 $z$ の複素共役を $\overline{z}$ 、虚数単位を $j$ と表します。
以下 $N$ が $2$ の冪として次数 $N-1$ の多項式を扱います。
$N-1$ 次多項式 $f(x)$ の離散フーリエ変換 $F(X)$ を $F(X) = \sum_{i=0}^{N-1} f(\zeta_N^i) X^i$ で定義します。
$f[k] = \frac{1}{N} F(\zeta_N^{-k})$ です(逆離散フーリエ変換)。詳しく知りたい方は参考欄のサイトを参照してください。
$F(\zeta_N^{-k}) = \sum_{i=0}^{N-1} F[i] \zeta_N^{-ki} = \overline{\sum_{i=0}^{N-1} \overline{F[i]} \zeta_N^{ki}}$ より、$F(X)$ の係数の複素共役を取った後で離散フーリエ変換を行い $N$ で割ることにより $f(x)$ が復元できます。
$2$ つの実数係数多項式 $f(x), g(x)$ のフーリエ変換 $F(X), G(X)$ を $1$ 回の離散フーリエ変換で求めることができます。
まず、$F[N] = F[0]$ として、$F[N-t] = \overline{F[t]}$ が成り立ちます。
これは、$f(x)$ が実数係数であることから $f[k] = \overline{f[k]}$ が成り立ち、$F[N-t] = \sum_{i=0}^{N-1} f[i] \zeta_N^{N-t} = \overline{\sum_{i=0}^{N-1} \overline{f[i] \zeta_N^{-t}}} = \overline{\sum_{i=0}^{N-1} f[i] \zeta_N^t} = F[t]$ より分かります。
ここで、$h(x) := f(x) + j g(x)$ として多項式 $h(x)$ を定義します。$h(x)$ の離散フーリエ変換 $H(X)$ に関して $H(X) = F(X) + j G(X)$ です。
一方、$\overline{H[N-t]} = \overline{F(N-t)} - \overline{G[N-t]} j = F[t] - j G[t]$ も成り立ちます。
これら $2$ つの等式を連立して解くことにより次の式が得られます。
\[F[t] = \frac{1}{2} (H[t] + \overline{H[N-t]})\] \[G[t] = \frac{1}{2i} (H[t] - \overline{H[N-t]})\]これより、$C[t] := F[t] G[t]$ とすると、$C[t] = \frac{1}{4i} (H[t] + \overline{H[N-t]})(H[t] - \overline{H[N-t]})$ です。
$C[t]$ に対して逆離散フーリエ変換を行うことにより積 $c(x) = f(x) g(x)$ が計算できます。
周波数間引きにより、$2$ つの $N/2 - 1$ 次の多項式 $c_e, c_o$ を $c_e[k] = c[2k], c_o[k] = c[2k+1]$ として定めると、
\[C[t] = C_e[t] + \zeta_N^t C_o[t]\] \[C[t + N/2] = C_e[t] - \zeta_N^t C_o[t]\]です。
これらを $C_e[t], C_o[t]$ について解くことにより、
\[C_e[t] = \frac{1}{2} (C[t] + C[t + N/2])\] \[C_o[t] = \frac{1}{2} \overline{\zeta_N^{t}} (C[t] - C[t + N/2])\]が得られ、$C(X)$ から $C_o(X), C_e(X)$ が求まることが分かります。
ここで、$p(x) := c_e(x) + j c_o(x)$ として多項式 $p(x)$ を定義します。
$p(x)$ の離散フーリエ変換 $P(X) = C_e(X) + j C_o(X)$ を逆離散フーリエ変換することにより $p(x) = c_e(x) + j c_o(x)$ が復元できることと、$c_e(x), c_o(x)$ は実数係数多項式なので $p[i]$ の実部、虚部がそれぞれ $c_e[i], c_o[i]$ に対応することに注意しておきます。
$C(X)$ は先程の “$2$ つの実数係数多項式の離散フーリエ変換を $1$ 回の離散フーリエ変換で行うテク” で求めることができるので、$C_e(X), C_o(X)$ は計算が可能です。$C_e(X), C_o(X)$ から $P(X)$ を構築して逆離散フーリエ変換を行うことにより $c_e(x), c_o(x)$ すなわち、$c(x)$ が計算できます。
したがって、長さが $N$ の離散フーリエ変換と長さが半分($N/2$) の逆離散フーリエ変換を $1$ 回ずつ行うことにより積 $f(x) g(x)$ が得られました。
$\overline{C[N-t]} = \overline{F[N-t]G[N-t]} = F[t]G[t] = C[t]$ より、 $\overline{C[t + N/2]} = C[N - (t + N/2)] = C[N/2 - t]$ です。
これを代入することにより、
\[C_e[t] = \frac{1}{2} (C[t] + \overline{C[N/2 - t]})\] \[C_o[t] = \frac{1}{2} \overline{\zeta_N^{t}} (C[t] - \overline{C[N/2 - t]})\]が得られます。
最後に、実装を読むときのためのメモを以下に記しておきます。
基数 $4$ の周波数間引き Cooley-Tukey 型アルゴリズムの実装も書いたので載せておきます。
static void fft(std::vector<complex_type> & a, uint32 log_n, const std::vector<complex_type> & zeta, uint32 log_z) {
const uint32 n = a.size();
auto zeta_f = [&](uint32 d, uint32 p) {
return zeta[p << (log_z - d)];
};
for (uint32 w = n, c = log_n; w >= 4; w >>= 2, c -= 2) {
const uint32 s = w >> 2;
for (uint32 p = 0; p < n; p += w) {
for (uint32 i = 0; i < s; ++i) {
const uint32 pos = p + i;
const complex_type a0 = a[pos], a2 = a[pos + (s << 1)];
const complex_type ep = a0 + a2, en = (a0 - a2) * zeta_f(c, i, zeta, log_z);
const complex_type a1 = a[pos + s], a3 = a[pos + w - s];
const complex_type op = a1 + a3, on = ie * (a1 - a3) * zeta_f(c, i, zeta, log_z);
a[pos] = ep + op;
a[pos + s] = (ep - op) * zeta_f(c - 1, i, zeta, log_z);
a[pos + (s << 1)] = en + on;
a[pos + w - s] = (en - on) * zeta_f(c - 1, i, zeta, log_z);
}
}
}
if (log_n & 1) {
for (uint32 i = 0; i < n; i += 2) {
const complex_type x = a[i], y = a[i + 1];
a[i] = x + y;
a[i + 1] = x - y;
}
}
bit_reverse(a);
}
TODO: 周波数間引きを用いて bit_reverese を消せるか考えてみる
TODO: zeta 配列を線形に見るように上手く変形する
2020/05/01: https://qiita.com/ageprocpp/items/0d63d4ed80de4a35fe79
2020/08/01: http://wwwa.pikara.ne.jp/okojisan/stockham/cooley-tukey.html
2021/04/16: http://www.ocw.titech.ac.jp/index.php?module=General&action=T0300&JWC=202002382&lang=JA&vid=03
#ifndef INCLUDE_GUARD_FAST_FOURIER_TRANSFORM_HPP
#define INCLUDE_GUARD_FAST_FOURIER_TRANSFORM_HPP
#include <vector>
#include <complex>
#include <algorithm>
#include <cstdint>
/**
* @brief https://tkmst201.github.io/Library/Convolution/FastFourierTransform.hpp
*/
struct FastFourierTransform {
using value_type = double;
using complex_type = std::complex<value_type>;
private:
using uint32 = std::uint32_t;
constexpr static value_type pi = 3.1415926535897932384626433832795028841972;
constexpr static complex_type ie{0, 1};
uint32 mlog_n;
std::vector<complex_type> zeta;
public:
explicit FastFourierTransform(uint32 max_n) : mlog_n(calc_l2(max_n)), zeta(zeta_(mlog_n)) {}
template<typename T>
std::vector<value_type> operator ()(const std::vector<T> & a, const std::vector<T> & b) const {
if (a.empty() || b.empty()) return {};
if (a.size() == 1 && b.size() == 1) return {static_cast<value_type>(a.front()) * b.front()};
assert((a.size() + b.size() - 1) <= (1u << mlog_n));
return multiply_sub(a, b, zeta, mlog_n);
}
template<typename T>
static std::vector<value_type> multiply(const std::vector<T> & a, const std::vector<T> & b) {
if (a.empty() || b.empty()) return {};
if (a.size() == 1 && b.size() == 1) return {static_cast<value_type>(a.front()) * b.front()};
const uint32 log_n = calc_l2(a.size() + b.size() - 1);
const std::vector<complex_type> zeta = zeta_(log_n);
return multiply_sub(a, b, zeta, log_n);
}
private:
template<typename T>
static std::vector<value_type> multiply_sub(const std::vector<T> & a, const std::vector<T> & b, const std::vector<complex_type> & zeta, uint32 log_z) {
const uint32 n_ = a.size() + b.size() - 1;
const uint32 log_n = calc_l2(n_), n = 1u << log_n, m = n >> 1;
std::vector<complex_type> h(n);
for (uint32 i = 0; i < a.size(); ++i) h[i].real(a[i]);
for (uint32 i = 0; i < b.size(); ++i) h[i].imag(b[i]);
fft(h, log_n, zeta, log_z);
std::vector<complex_type> p(m);
{
const value_type cf = h[0].real() * h[0].imag();
const value_type cb = h[m].real() * h[m].imag();
p[0] = complex_type(cf + cb, -(cf - cb)) / 2.0;
}
for (uint32 i = 1; i <= (m >> 1); ++i) {
const complex_type cf = -(h[i] + std::conj(h[n - i])) * (h[i] - std::conj(h[n - i])) * ie;
const complex_type cb = -(h[m - i] + std::conj(h[m + i])) * (h[m - i] - std::conj(h[m + i])) * ie;
p[i] = std::conj((cf + std::conj(cb) + (cf - std::conj(cb)) * std::conj(zeta_f(log_n, i, zeta, log_z)) * ie)) / 8.0;
if (i != m / 2) p[m - i] = std::conj((cb + std::conj(cf)) + (cb - std::conj(cf)) * std::conj(zeta_f(log_n, m - i, zeta, log_z)) * ie) / 8.0;
}
fft(p, log_n - 1, zeta, log_z);
std::vector<value_type> res;
res.reserve(n_);
for (uint32 i = 0; i < m; ++i) {
if ((i << 1) < n_) res.emplace_back(p[i].real() / static_cast<value_type>(m));
if ((i << 1 | 1) < n_) res.emplace_back(-p[i].imag() / static_cast<value_type>(m));
}
return res;
}
static void fft(std::vector<complex_type> & a, uint32 log_n, const std::vector<complex_type> & zeta, uint32 log_z) {
const uint32 n = a.size(), m = n >> 1;
bit_reverse(a);
for (uint32 w = 4, c = 2; w <= n; w <<= 2, c += 2) {
const uint32 s = w >> 2;
for (uint32 p = 0; p < n; p += w) {
for (uint32 i = 0; i < s; ++i) {
const uint32 pos = p + i;
const complex_type a0 = a[pos], a2 = a[pos + s] * zeta_f(c - 1, i, zeta, log_z);
const complex_type a1 = a[pos + (s << 1)] * zeta_f(c, i, zeta, log_z), a3 = a[pos + w - s] * zeta_f(c, 3 * i, zeta, log_z);
const complex_type lp = a0 + a2, rp = a1 + a3, ln = a0 - a2, rn = a1 - a3;
a[pos] = lp + rp;
a[pos + (s << 1)] = lp - rp;
a[pos + s] = ln + rn * ie;
a[pos + w - s] = ln - rn * ie;
}
}
}
if (~log_n & 1) return;
for (uint32 i = 0; i < m; ++i) {
const complex_type x = a[i], y = a[i + m] * zeta_f(log_n, i, zeta, log_z);
a[i] = x + y;
a[i + m] = x - y;
}
}
static uint32 calc_l2(uint32 n_) noexcept {
uint32 log_n = 0;
for (uint32 n = 1; n < n_; n <<= 1) ++log_n;
return log_n;
}
static void bit_reverse(std::vector<complex_type> & a) noexcept {
const uint32 N = a.size();
for (uint32 i = 1, j = 0; i < N - 1; ++i) {
for (uint32 k = N >> 1; k > (j ^= k); k >>= 1);
if (i < j) std::swap(a[i], a[j]);
}
}
static std::vector<complex_type> zeta_(uint32 log_n) {
if (log_n == 0) return {};
std::vector<complex_type> zeta;
zeta.reserve(1 << (log_n - 1));
zeta.emplace_back(1, 0);
for (uint32 i = 0; i < (log_n - 1); ++i) {
const complex_type t = std::polar<value_type>(1, 2.0 * pi / static_cast<value_type>(1 << (log_n - i)));
zeta.emplace_back(t);
for (uint32 j = 1; j < (1u << i); ++j) zeta.emplace_back(zeta[j] * t);
}
return zeta;
}
static complex_type zeta_f(uint32 d, uint32 p, const std::vector<complex_type> & zeta, uint32 log_z) noexcept {
const uint32 idx = p << (log_z - d);
return idx < zeta.size() ? zeta[idx] : -zeta[idx - zeta.size()];
}
};
#endif // INCLUDE_GUARD_FAST_FOURIER_TRANSFORM_HPP
#line 1 "Convolution/FastFourierTransform.hpp"
#include <vector>
#include <complex>
#include <algorithm>
#include <cstdint>
/**
* @brief https://tkmst201.github.io/Library/Convolution/FastFourierTransform.hpp
*/
struct FastFourierTransform {
using value_type = double;
using complex_type = std::complex<value_type>;
private:
using uint32 = std::uint32_t;
constexpr static value_type pi = 3.1415926535897932384626433832795028841972;
constexpr static complex_type ie{0, 1};
uint32 mlog_n;
std::vector<complex_type> zeta;
public:
explicit FastFourierTransform(uint32 max_n) : mlog_n(calc_l2(max_n)), zeta(zeta_(mlog_n)) {}
template<typename T>
std::vector<value_type> operator ()(const std::vector<T> & a, const std::vector<T> & b) const {
if (a.empty() || b.empty()) return {};
if (a.size() == 1 && b.size() == 1) return {static_cast<value_type>(a.front()) * b.front()};
assert((a.size() + b.size() - 1) <= (1u << mlog_n));
return multiply_sub(a, b, zeta, mlog_n);
}
template<typename T>
static std::vector<value_type> multiply(const std::vector<T> & a, const std::vector<T> & b) {
if (a.empty() || b.empty()) return {};
if (a.size() == 1 && b.size() == 1) return {static_cast<value_type>(a.front()) * b.front()};
const uint32 log_n = calc_l2(a.size() + b.size() - 1);
const std::vector<complex_type> zeta = zeta_(log_n);
return multiply_sub(a, b, zeta, log_n);
}
private:
template<typename T>
static std::vector<value_type> multiply_sub(const std::vector<T> & a, const std::vector<T> & b, const std::vector<complex_type> & zeta, uint32 log_z) {
const uint32 n_ = a.size() + b.size() - 1;
const uint32 log_n = calc_l2(n_), n = 1u << log_n, m = n >> 1;
std::vector<complex_type> h(n);
for (uint32 i = 0; i < a.size(); ++i) h[i].real(a[i]);
for (uint32 i = 0; i < b.size(); ++i) h[i].imag(b[i]);
fft(h, log_n, zeta, log_z);
std::vector<complex_type> p(m);
{
const value_type cf = h[0].real() * h[0].imag();
const value_type cb = h[m].real() * h[m].imag();
p[0] = complex_type(cf + cb, -(cf - cb)) / 2.0;
}
for (uint32 i = 1; i <= (m >> 1); ++i) {
const complex_type cf = -(h[i] + std::conj(h[n - i])) * (h[i] - std::conj(h[n - i])) * ie;
const complex_type cb = -(h[m - i] + std::conj(h[m + i])) * (h[m - i] - std::conj(h[m + i])) * ie;
p[i] = std::conj((cf + std::conj(cb) + (cf - std::conj(cb)) * std::conj(zeta_f(log_n, i, zeta, log_z)) * ie)) / 8.0;
if (i != m / 2) p[m - i] = std::conj((cb + std::conj(cf)) + (cb - std::conj(cf)) * std::conj(zeta_f(log_n, m - i, zeta, log_z)) * ie) / 8.0;
}
fft(p, log_n - 1, zeta, log_z);
std::vector<value_type> res;
res.reserve(n_);
for (uint32 i = 0; i < m; ++i) {
if ((i << 1) < n_) res.emplace_back(p[i].real() / static_cast<value_type>(m));
if ((i << 1 | 1) < n_) res.emplace_back(-p[i].imag() / static_cast<value_type>(m));
}
return res;
}
static void fft(std::vector<complex_type> & a, uint32 log_n, const std::vector<complex_type> & zeta, uint32 log_z) {
const uint32 n = a.size(), m = n >> 1;
bit_reverse(a);
for (uint32 w = 4, c = 2; w <= n; w <<= 2, c += 2) {
const uint32 s = w >> 2;
for (uint32 p = 0; p < n; p += w) {
for (uint32 i = 0; i < s; ++i) {
const uint32 pos = p + i;
const complex_type a0 = a[pos], a2 = a[pos + s] * zeta_f(c - 1, i, zeta, log_z);
const complex_type a1 = a[pos + (s << 1)] * zeta_f(c, i, zeta, log_z), a3 = a[pos + w - s] * zeta_f(c, 3 * i, zeta, log_z);
const complex_type lp = a0 + a2, rp = a1 + a3, ln = a0 - a2, rn = a1 - a3;
a[pos] = lp + rp;
a[pos + (s << 1)] = lp - rp;
a[pos + s] = ln + rn * ie;
a[pos + w - s] = ln - rn * ie;
}
}
}
if (~log_n & 1) return;
for (uint32 i = 0; i < m; ++i) {
const complex_type x = a[i], y = a[i + m] * zeta_f(log_n, i, zeta, log_z);
a[i] = x + y;
a[i + m] = x - y;
}
}
static uint32 calc_l2(uint32 n_) noexcept {
uint32 log_n = 0;
for (uint32 n = 1; n < n_; n <<= 1) ++log_n;
return log_n;
}
static void bit_reverse(std::vector<complex_type> & a) noexcept {
const uint32 N = a.size();
for (uint32 i = 1, j = 0; i < N - 1; ++i) {
for (uint32 k = N >> 1; k > (j ^= k); k >>= 1);
if (i < j) std::swap(a[i], a[j]);
}
}
static std::vector<complex_type> zeta_(uint32 log_n) {
if (log_n == 0) return {};
std::vector<complex_type> zeta;
zeta.reserve(1 << (log_n - 1));
zeta.emplace_back(1, 0);
for (uint32 i = 0; i < (log_n - 1); ++i) {
const complex_type t = std::polar<value_type>(1, 2.0 * pi / static_cast<value_type>(1 << (log_n - i)));
zeta.emplace_back(t);
for (uint32 j = 1; j < (1u << i); ++j) zeta.emplace_back(zeta[j] * t);
}
return zeta;
}
static complex_type zeta_f(uint32 d, uint32 p, const std::vector<complex_type> & zeta, uint32 log_z) noexcept {
const uint32 idx = p << (log_z - d);
return idx < zeta.size() ? zeta[idx] : -zeta[idx - zeta.size()];
}
};