This documentation is automatically generated by online-judge-tools/verification-helper
#include "DataStructure/SegmentTree.hpp"
配列を扱うデータ構造です。
要素数 $N$ の配列に対し、1 点更新や区間に対する演算をそれぞれ $\Theta(\log{N})$、1 点取得を $\Theta(1)$ で行うことができます。
区間に対して一意に値が定まり、区間をまとめて計算できるような演算が扱えます。例: +
, xor
, min
, gcd
, 関数の合成
など。
区間更新を行いたい場合は、遅延伝搬セグメント木を使用してください。
SegmentTree(size_t n, const T & id_elem, const F & f)
SegmentTree(const vector<T> & v, const T & id_elem, const F & f)
size_t size()
void set(size_t i, const T & x)
const T & get(size_t i)
T fold(size_t l, size_t r)
T fold_all()
size_t max_right(size_t l, std::function<bool (const T &)> g)
size_t min_left(size_t r, std::function<bool (const T &)> g)
F
は二項演算 std::function<T (const T &, const T &)>
の略記です。
制約
F
の単位元は id_elem
id_elem
$)$ はモノイド要素数 $n$ で初期化します。
初期値は単位元 id_elem
です。
計算量
配列 v
で初期化します。
計算量
エラーは出ませんが、初期化しないまま使用した場合の動作は保証されていません。
計算量
以下、要素数 $N$ の配列 $A_0, A_1, \ldots, A_{N-1}$ を対象とします。 二項演算は $f$ です。
配列の要素数 $N$ を返します。
計算量
$A_i$ を $x$ に変更します。
制約
計算量
$A_i$ を返します。
制約
計算量
半開区間 $[l, r)$ の演算結果 $f(A_l, f(A_{l+1}, f(\ldots, f(A_{r-2}, A_{r-1}))\ldots)$ を返します。 $l = r$ のときは単位元を返します。
制約
計算量
$fold(0,N)$ の計算結果 $f(A_0, f(A_1, f(\ldots, f(A_{N-2}, A_{N-1}))\ldots))$ を返します。
計算量
$g(fold(l, r)) = true$ となるような最小の $r$ を返します。 $g(fold(l, N)) = true$ または $l = N$ のときは $N$ を返します。
制約
id_elem
$) = true$計算量
Verified
$g(fold(l, r)) = true$ となるような最大の $l$ を返します。 $g(fold(0, r)) = true$ または $l = 0$ のときは $0$ を返します。
制約
id_elem
$) = true$計算量
Verified
和を扱うセグメント木の例です。オーバーフローに注意してください。
総和が $2^{31}$ 以上になる場合は long long
を使いましょう。
#include <bits/stdc++.h>
#include "DataStructure/SegmentTree.hpp"
using namespace std;
int main() {
const vector<int> A {1, 2, 3, 0, 0, 0, 4, 5};
// 和を扱うセグメント木
SegmentTree<int> seg(A, 0, [](auto x, auto y) { return x + y; });
cout << "N = " << seg.size() << endl; // 8 (= N)
cout << "sum = " << seg.fold_all() << endl; // 15
cout << "sum[0, 2) = " << seg.fold(0, 2) << endl; // 3
cout << "sum[0, 0) = " << seg.fold(0, 0) << endl; // 0 (= id_elem)
// A[0..] で合計が 6 以下となるような最大の index を求める (index = r - 1)
int r = seg.max_right(0, [](auto x) { return x <= 6; });
cout << "r = " << r << endl; // 6
cout << "sum[0, r) = " << seg.fold(0, r) << " (<= 6)" << endl; // 6
// A[..4] で合計が 4 以下となるような最小の index を求める (index = l)
int l = seg.min_left(5, [](auto x) { return x <= 4; });
cout << "l = " << l << endl; // 2
cout << "sum[l, 5) = " << seg.fold(l, 5) << " (<= 4)" << endl; // 3
cout << "seg[2] = " << seg.get(2) << endl; // 3
cout << "seg.set(2, 5)" << endl;
seg.set(2, 5);
cout << "seg[2] = " << seg.get(2) << endl; // 5
cout << "sum = " << seg.fold_all() << endl; // 17
}
最小値を扱うセグメント木の例です。
#include <bits/stdc++.h>
#include "DataStructure/SegmentTree.hpp"
using namespace std;
constexpr int INF = 1<<30;
int main() {
// 最小値を扱うセグメント木
SegmentTree<int> seg(5, numeric_limits<int>::max(), [](auto x, auto y) { return min(x, y); });
seg.set(0, 1);
seg.set(1, 5);
seg.set(2, 3);
seg.set(3, 7);
seg.set(4, 0);
// [1, 5, 3, 7, 0]
cout << "N = " << seg.size() << endl; // 5 (= N)
cout << "min = " << seg.fold_all() << endl; // 0
cout << "min[2, 4) = " << seg.fold(2, 4) << endl; // 3
cout << "min[3, 3) = " << seg.fold(3, 3) << endl; // INF (= id_elem)
// A[1..] の中で 2 以下の値でもっとも左側の index を求める (index = r)
int r = seg.max_right(1, [](int x) { return x > 2; });
cout << "r = " << r << endl; // 4
cout << "sum[1, r) = " << seg.fold(1, r) << " (> 2)" << endl; // 3
// A[..4] で最小値が 0 以上の最小の index を求める (index = l)
int l = seg.min_left(5, [](int x) { return x >= 0; });
cout << "l = " << l << endl; // 0
cout << "sum[l, 5) = " << seg.fold(l, 5) << " (>= 0)" << endl; // 0
}
よく使いそうなセグメント木の定義をいくつか載せておきます。
xor
セグメント木
SegmentTree<int> seg(N, 0, [](auto x, auto y) { return x ^ y; });
一次関数セグメント木。mint
は ModInt
構造体の略記です。$Ax+B$ の $A$ が first
に、$B$ が second
に対応しています。演算は一次関数の合成です。
using pmm = std::pair<mint, mint>;
SegmentTree<pmm> seg(N, pmm(1, 0), [](const auto & a, const auto & b) -> pmm {
return {a.first * b.first, b.first * a.second + b.second};
});
内部で扱っているノードの数を n_
とし、ノードは 1-indexed で管理しています。
min_left
も同様なので max_right
についてのみ書きます。
$A_l, (\ldots)$ を含むノード ($A_{l-1}$ は含まない) 最も根側のノードまで上りその値を加えて右上の部分木に移動することを繰り返すことにより $g(fold(l, r)) = false$ となるような $r$ を求めることが可能です。
上っている途中で $g(fold(l, r))$ が満たさなくなったら $r$ の部分木の中に解が存在するはずです。$g(fold(l, r の左部分木の右端 + 1))$ が $false$ ならば左部分木中に、 $true$ ならば右部分木中に解があると分かります。
(l & -l)
で $l$ の最下位ビットを取り出すことができます。これを利用して $2$ べきであるかの判定を行っています。
TODO: max_right
, min_left
の test を追加する
TODO: InputIterator
コンストラクタに変更
TODO: セグ木構築のための余計な vector の作成を避けたいので move に対応したい
2020/04/08: https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
2020/09/13: AC Library
#ifndef INCLUDE_GUARD_SEGMENT_TREE_HPP
#define INCLUDE_GUARD_SEGMENT_TREE_HPP
#include <vector>
#include <algorithm>
#include <cassert>
#include <functional>
/**
* @brief https://tkmst201.github.io/Library/DataStructure/SegmentTree.hpp
*/
template<typename T>
struct SegmentTree {
using value_type = T;
using const_reference = const value_type &;
using F = std::function<value_type (const_reference, const_reference)>;
using size_type = std::size_t;
private:
size_type n, n_;
value_type id_elem;
F f;
std::vector<value_type> node;
public:
SegmentTree() = default;
SegmentTree(size_type n, const_reference id_elem, const F & f)
: n(n), id_elem(id_elem), f(f) {
n_ = 1;
while (n_ < n) n_ <<= 1;
node.assign(2 * n_, id_elem);
}
SegmentTree(const std::vector<value_type> & v, const_reference id_elem, const F & f)
: SegmentTree(v.size(), id_elem, f) {
for (size_type i = 0; i < v.size(); ++i) node[i + n_] = v[i];
for (size_type i = n_ - 1; i > 0; --i) node[i] = f(node[i << 1], node[i << 1 | 1]);
}
size_type size() const noexcept {
return n;
}
void set(size_type i, const_reference x) noexcept {
assert(i < size());
node[i += n_] = x;
while (i > 1) {
i >>= 1;
node[i] = f(node[i << 1], node[i << 1 | 1]);
}
}
const_reference get(size_type i) const noexcept {
assert(i < size());
return node[i + n_];
}
value_type fold(size_type l, size_type r) const noexcept {
assert(l <= r);
assert(r <= size());
value_type lv = id_elem, rv = id_elem;
for (l += n_, r += n_; l < r; l >>= 1, r >>= 1) {
if (l & 1) lv = f(lv, node[l++]);
if (r & 1) rv = f(node[r - 1], rv);
}
return f(lv, rv);
}
value_type fold_all() const noexcept {
return node[1];
}
size_type max_right(size_type l, std::function<bool (const_reference)> g) const noexcept {
assert(l <= size());
assert(g(id_elem));
if (l == size()) return size();
l += n_;
value_type sum = id_elem;
while (true) {
while (~l & 1) l >>= 1;
const value_type nex_sum = f(sum, node[l]);
if (g(nex_sum)) { sum = nex_sum; ++l; }
else break;
if ((l & -l) == l) return size();
}
while (l < n_) {
const value_type nex_sum = f(sum, node[l << 1]);
l <<= 1;
if (g(nex_sum)) { sum = nex_sum; l |= 1; }
}
return l - n_;
}
size_type min_left(size_type r, std::function<bool (const_reference)> g) const noexcept {
assert(r <= size());
assert(g(id_elem));
if (r == 0) return 0;
r += n_;
value_type sum = id_elem;
while (true) {
--r;
while (r > 1 && (r & 1)) r >>= 1;
const value_type nex_sum = f(node[r], sum);
if (g(nex_sum)) sum = nex_sum;
else break;
if ((r & -r) == r) return 0;
}
while (r < n_) {
const value_type nex_sum = f(node[r << 1 | 1], sum);
r <<= 1;
if (!g(nex_sum)) r |= 1;
else sum = nex_sum;
}
return r + 1 - n_;
}
};
#endif // INCLUDE_GUARD_SEGMENT_TREE_HPP
#line 1 "DataStructure/SegmentTree.hpp"
#include <vector>
#include <algorithm>
#include <cassert>
#include <functional>
/**
* @brief https://tkmst201.github.io/Library/DataStructure/SegmentTree.hpp
*/
template<typename T>
struct SegmentTree {
using value_type = T;
using const_reference = const value_type &;
using F = std::function<value_type (const_reference, const_reference)>;
using size_type = std::size_t;
private:
size_type n, n_;
value_type id_elem;
F f;
std::vector<value_type> node;
public:
SegmentTree() = default;
SegmentTree(size_type n, const_reference id_elem, const F & f)
: n(n), id_elem(id_elem), f(f) {
n_ = 1;
while (n_ < n) n_ <<= 1;
node.assign(2 * n_, id_elem);
}
SegmentTree(const std::vector<value_type> & v, const_reference id_elem, const F & f)
: SegmentTree(v.size(), id_elem, f) {
for (size_type i = 0; i < v.size(); ++i) node[i + n_] = v[i];
for (size_type i = n_ - 1; i > 0; --i) node[i] = f(node[i << 1], node[i << 1 | 1]);
}
size_type size() const noexcept {
return n;
}
void set(size_type i, const_reference x) noexcept {
assert(i < size());
node[i += n_] = x;
while (i > 1) {
i >>= 1;
node[i] = f(node[i << 1], node[i << 1 | 1]);
}
}
const_reference get(size_type i) const noexcept {
assert(i < size());
return node[i + n_];
}
value_type fold(size_type l, size_type r) const noexcept {
assert(l <= r);
assert(r <= size());
value_type lv = id_elem, rv = id_elem;
for (l += n_, r += n_; l < r; l >>= 1, r >>= 1) {
if (l & 1) lv = f(lv, node[l++]);
if (r & 1) rv = f(node[r - 1], rv);
}
return f(lv, rv);
}
value_type fold_all() const noexcept {
return node[1];
}
size_type max_right(size_type l, std::function<bool (const_reference)> g) const noexcept {
assert(l <= size());
assert(g(id_elem));
if (l == size()) return size();
l += n_;
value_type sum = id_elem;
while (true) {
while (~l & 1) l >>= 1;
const value_type nex_sum = f(sum, node[l]);
if (g(nex_sum)) { sum = nex_sum; ++l; }
else break;
if ((l & -l) == l) return size();
}
while (l < n_) {
const value_type nex_sum = f(sum, node[l << 1]);
l <<= 1;
if (g(nex_sum)) { sum = nex_sum; l |= 1; }
}
return l - n_;
}
size_type min_left(size_type r, std::function<bool (const_reference)> g) const noexcept {
assert(r <= size());
assert(g(id_elem));
if (r == 0) return 0;
r += n_;
value_type sum = id_elem;
while (true) {
--r;
while (r > 1 && (r & 1)) r >>= 1;
const value_type nex_sum = f(node[r], sum);
if (g(nex_sum)) sum = nex_sum;
else break;
if ((r & -r) == r) return 0;
}
while (r < n_) {
const value_type nex_sum = f(node[r << 1 | 1], sum);
r <<= 1;
if (!g(nex_sum)) r |= 1;
else sum = nex_sum;
}
return r + 1 - n_;
}
};