This documentation is automatically generated by online-judge-tools/verification-helper
#include "DataStructure/AVL_Tree.hpp"
順序付き集合を扱う AVL 木です。 要素数 $N$ に対し、基本的な操作を $\Theta(\log{N})$ で行うことができます。
AVL_Tree()
AVL_Tree(const AVL_Tree & rhs)
rhs
$|)$ rhs
をコピーAVL_Tree(AVL_Tree && rhs)
rhs
をムーブbool empty()
size_t size()
void clear()
std::vector<T> enumerate()
Node * begin()
Node * end()
Node * insert(const T & x)
Node * erase(const T & x)
Node * erase(Node * q)
Node * find(const T & x)
Node * lower_bound(const T & x)
Node * upper_bound(const T & x)
size_t count_less_than(const T & x)
size_t count_less_equal(const T & x)
size_t count_greater_than(const T & x)
size_t count_greater_equal(const T & x)
size_t count(const T & x)
Node * k_th_smallest(uint32_t k)
Node * k_th_largest(uint32_t k)
Node * next(Node * q)
Node * prev(Node * q)
以降、要素数を $N$ とします。
制約
T
は比較可能空の AVL 木を作成します。
計算量
rhs
をコピーします。
計算量
rhs
$|)$rhs
をムーブします。
ムーブ後の rhs
は空の AVL 木となります。
計算量
すべての要素は値の昇順に、同じ値同士は追加した順に順序付けされているとします。
AVL 木が空であるか判定します。 空なら $true$ 、そうでないなら $false$ を返します。
計算量
要素数 $N$ を返します。
計算量
すべての要素を削除します。
計算量
すべての要素を昇順に列挙します。
計算量
先頭の要素のポインタを返します。
計算量
最後の要素の次の位置を示すポインタを返します。
計算量
値 $x$ を追加します。 追加した要素のポインタを返します。
計算量
値 $x$ と等しい要素の中で先頭の要素を削除し、その次の要素のポインタを返します。
値 $x$ が存在しない場合は何も行わず、end()
を返します。
計算量
ポインタ $q$ が指す要素を削除し、その次の要素のポインタを返します。
制約
計算量
値 $x$ と等しい要素の中で先頭の要素のポインタを返します。
そのような要素が存在しない場合は end()
を返します。
計算量
値 $x$ 以上の要素の中で先頭の要素のポインタを返します。
そのような要素が存在しない場合は end()
を返します。
計算量
値 $x$ より大きい要素の中で先頭の要素のポインタを返します。
そのような要素が存在しない場合は end()
を返します。
計算量
値 $x$ 未満の要素の個数を返します。
計算量
値 $x$ 以下の要素の個数を返します。
計算量
値 $x$ より大きい要素の個数を返します。
計算量
値 $x$ 以上の要素の個数を返します。
計算量
値 $x$ と等しい要素の個数を返します。
計算量
先頭から $k$ 番目の要素のポインタを返します。
$k = 0$ またはそのような要素が存在しない場合は end()
を返します。
制約
計算量
後ろから $k$ 番目の要素のポインタを返します。
$k = 0$ またはそのような要素が存在しない場合は end()
を返します。
制約
計算量
ポインタ $q$ が指す要素の次の要素のポインタを返します。
$q =$ end()
の場合は begin()
を返します。
制約
計算量
ポインタ $q$ が指す要素の $1$ つ前の要素のポインタを返します。
$q =$ begin()
の場合は end()
を返します。
制約
計算量
#include <bits/stdc++.h>
#include "DataStructure/AVL_Tree.hpp"
using namespace std;
int main() {
AVL_Tree<int> avl;
cout << "size = " << avl.size() << endl; // 0
for (int i : {1, 2, 2, 3, 4, 5, 6, 6, 6, 7}) avl.insert(i);
cout << "size = " << avl.size() << endl; // 10
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // 1 2 2 3 4 5 6 6 6 7
cout << "min = " << avl.begin()->val << endl; // 1
cout << "max = " << avl.prev(avl.end())->val << endl; // 7
auto it = avl.find(2);
cout << "find(2) : " << it->val << endl; // 2
cout << "next : " << avl.next(it)->val << endl; // 2
cout << "prev : " << avl.prev(it)->val << endl; // 1
avl.erase(it);
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 6 7
avl.erase(6);
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
cout << "lower_bound(-1) = " << avl.lower_bound(-1)->val << endl; // 1
cout << "upper_bound(6) = " << avl.upper_bound(6)->val << endl; // 7
cout << "upper_bound(7) == end() : " << boolalpha << (avl.upper_bound(7) == avl.end()) << endl; // true
cout << "count_less_than(3) = " << avl.count_less_than(3) << endl; // 2
cout << "count_less_equal(4) = " << avl.count_less_equal(4) << endl; // 4
cout << "count_greater_than(6) = " << avl.count_greater_than(6) << endl; // 1
cout << "count_greater_equal(6) = " << avl.count_greater_equal(6) << endl; // 3
cout << "count(6) = " << avl.count(6) << endl; // 2
cout << "count_greater_than(7) = " << avl.count_greater_than(7) << endl; // 0
cout << "k_th_smallest(3) = " << avl.k_th_smallest(3)->val << endl; // 3
cout << "k_th_largest(4) = " << avl.k_th_largest(4)->val << endl; // 5
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
{
AVL_Tree<int> avl2 = avl;
for (int i : avl2.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
}
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
{
AVL_Tree<int> avl2 = std::move(avl);
for (int i : avl2.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
}
for (int i : avl.enumerate()) cout << i << " "; cout << '\n'; // empty
cout << "empty = " << boolalpha << avl.empty() << endl; // true
}
ここでは、根と葉のパスに含まれる辺の個数の最大値を AVL 木の高さと定義し、頂点数 $N$ の AVL 木の高さがどれくらいになるか調べます。
$C_i :=$ 高さが $i$ である AVL 木の最小頂点数 と定義します。 AVL 木の平衡条件より左部分木と右部分木の高さの差は $1$ 以下で、高さが低い方が最小頂点数は少ないので次のような関係式が成り立ちます。
\[C_0 = 1, C_1 = 2, C_i = C_{i-1} + C_{i-2} + 1\ (i \geq 2)\]フィボナッチ数列 $F_0 = 1, F_1 = 1, F_i = F_{i-1} + F_{i-2}\ (i \geq 2)$ と比較すると $C_i \geq F_i$ より、頂点数 $N$ に対して高さは $\mathcal{O}(\log{N})$ であることが分かります。
また、高さ $h$ に対して頂点数が最大となるのは完全二分木 (葉以外の頂点がすべて $2$ つの子を持つ) となっているときで、頂点数 $N$ に対して最小の高さは $\Omega(\log{N})$ です。
以上より、頂点数 $N$ の AVL 木の高さは $\Theta(\log{N})$ であることが分かりました。
TODO: イテレータを実装(アドレス解決演算子やアロー演算子で val
にアクセスしたい)
TODO: size
, height
を子ではなく自身で持つように変更(必要な空間が減る)
2019/11/19: https://ja.wikipedia.org/wiki/AVL%E6%9C%A8
#ifndef INCLUDE_GUARD_AVL_TREE_HPP
#define INCLUDE_GUARD_AVL_TREE_HPP
#include <algorithm>
#include <cstdint>
#include <vector>
#include <utility>
#include <stack>
/**
* @brief https://tkmst201.github.io/Library/DataStructure/AVL_Tree.hpp
*/
template<typename T>
struct AVL_Tree {
using size_type = std::size_t;
using value_type = T;
using const_reference = const value_type &;
private:
using uint32 = std::uint32_t;
using int8 = std::int8_t;
public:
struct Node;
using node_ptr = Node *;
using const_ptr = const Node * const;
struct Node {
value_type val;
node_ptr par, child[2] {nullptr, nullptr};
bool isr;
int8 height[2] {};
uint32 size[2] {};
Node(const_reference val, node_ptr par, bool isr) : val(val), par(par), isr(isr) {}
};
private:
size_type n = 0;
node_ptr root = nullptr;
node_ptr e_ptr[2] {nullptr, nullptr};
public:
AVL_Tree() = default;
AVL_Tree(const AVL_Tree & rhs) {
*this = rhs;
}
AVL_Tree(AVL_Tree && rhs) {
*this = std::forward<AVL_Tree>(rhs);
}
~AVL_Tree() {
clear();
}
AVL_Tree & operator =(const AVL_Tree & rhs) {
if (this != &rhs) {
clear();
auto dfs = [](auto self, const_ptr q, node_ptr r) -> node_ptr {
if (!q) return nullptr;
node_ptr res = new Node(q->val, r, q->isr);
for (int i = 0; i < 2; ++i) {
res->height[i] = q->height[i];
res->size[i] = q->size[i];
res->child[i] = self(self, q->child[i], res);
}
return res;
};
root = dfs(dfs, rhs.root, nullptr);
n = rhs.n;
e_ptr[0] = e_ptr[1] = root;
if (root) for (int i = 0; i < 2; ++i) while (e_ptr[i]->child[i]) e_ptr[i] = e_ptr[i]->child[i];
}
return *this;
}
AVL_Tree & operator =(AVL_Tree && rhs) {
if (this != &rhs) {
clear();
n = rhs.n;
rhs.n = 0;
root = rhs.root;
rhs.root = nullptr;
std::copy(rhs.e_ptr, rhs.e_ptr + 2, e_ptr);
std::fill(rhs.e_ptr, rhs.e_ptr + 2, nullptr);
}
return *this;
}
bool empty() const noexcept {
return size() == 0;
}
size_type size() const noexcept {
return n;
}
void clear() {
if (!root) return;
std::stack<node_ptr> stk;
stk.emplace(root);
while (!stk.empty()) {
node_ptr node = stk.top();
stk.pop();
if (node->child[0]) stk.emplace(node->child[0]);
if (node->child[1]) stk.emplace(node->child[1]);
delete node;
}
n = 0;
root = nullptr;
std::fill(e_ptr, e_ptr + 2, nullptr);
}
std::vector<value_type> enumerate() const {
std::vector<value_type> elements;
elements.reserve(size());
auto dfs = [&elements](auto self, const_ptr q) -> void {
if (!q) return;
self(self, q->child[0]);
elements.emplace_back(q->val);
self(self, q->child[1]);
};
dfs(dfs, root);
return elements;
}
node_ptr begin() const noexcept {
return e_ptr[0];
}
node_ptr end() const noexcept {
return nullptr;
}
node_ptr insert(const_reference x) {
node_ptr q = root, r = nullptr;
bool ef[2] {}, d = false;
while (q) {
r = q;
d = q->val <= x;
q = q->child[d];
ef[!d] = true;
}
q = new Node(x, r, d);
++n;
if (!ef[0]) e_ptr[0] = q;
if (!ef[1]) e_ptr[1] = q;
if (r) {
r->size[d] = 1;
r->height[d] = 1;
r->child[d] = q;
update(r);
}
else root = q;
return q;
}
node_ptr erase(const_reference x) noexcept {
node_ptr q = find(x);
if (q == end()) return end();
return erase(q);
}
node_ptr erase(node_ptr q) noexcept {
if (!q) return end();
const node_ptr ret = next(q);
if (q->child[0] && q->child[1]) {
node_ptr p = q->child[0];
while (p->child[1]) p = p->child[1];
q->val = std::move(p->val);
q = p;
}
if (e_ptr[0] == q) e_ptr[0] = next(q);
if (e_ptr[1] == q) e_ptr[1] = prev(q);
const node_ptr r = q->par;
if (q->child[0] || q->child[1]) {
const node_ptr p = q->child[0] ? q->child[0] : q->child[1];
if (r) {
r->size[q->isr] = q->size[p->isr];
r->height[q->isr] = q->height[p->isr];
r->child[q->isr] = p;
p->par = r;
p->isr = q->isr;
}
else {
p->par = nullptr;
root = p;
}
}
else if (r) {
r->size[q->isr] = 0;
r->height[q->isr] = 0;
r->child[q->isr] = nullptr;
}
else root = nullptr;
delete q;
--n;
if (r) update(r);
return ret;
}
node_ptr find(const_reference x) const noexcept {
const node_ptr q = lower_bound(x);
if (q != end() && q->val != x) return end();
return q;
}
node_ptr lower_bound(const_reference x) const noexcept {
node_ptr q = root;
if (!q) return end();
while (q->child[q->val < x]) q = q->child[q->val < x];
if (q->val < x) q = next(q);
return q;
}
node_ptr upper_bound(const_reference x) const noexcept {
node_ptr q = root;
if (!q) return end();
while (q->child[q->val <= x]) q = q->child[q->val <= x];
if (q->val <= x) q = next(q);
return q;
}
size_type count_less_than(const_reference x) const noexcept {
size_type res = 0;
node_ptr q = root;
while (q != nullptr) {
bool r = q->val < x;
if (r) res += q->size[0] + 1;
q = q->child[r];
}
return res;
}
size_type count_less_equal(const_reference x) const noexcept {
size_type res = 0;
node_ptr q = root;
while (q != nullptr) {
bool r = q->val <= x;
if (r) res += q->size[0] + 1;
q = q->child[r];
}
return res;
}
size_type count_greater_than(const_reference x) const noexcept {
return size() - count_less_equal(x);
}
size_type count_greater_equal(const_reference x) const noexcept {
return size() - count_less_than(x);
}
size_type count(const_reference x) const noexcept {
return count_less_equal(x) - count_less_than(x);
}
node_ptr k_th_smallest(uint32 k) const noexcept {
if (k == 0 || n < k) return end();
node_ptr q = root;
while (k != q->size[0] + 1) {
if (k > q->size[0] + 1) k -= q->size[0] + 1, q = q->child[1];
else q = q->child[0];
}
return q;
}
node_ptr k_th_largest(uint32 k) const noexcept {
if (k == 0 || n < k) return end();
return k_th_smallest(n - k + 1);
}
node_ptr next(node_ptr q) const noexcept {
return move(q, true);
}
node_ptr prev(node_ptr q) const noexcept {
return move(q, false);
}
private:
node_ptr rotate(node_ptr q, bool d) noexcept {
node_ptr r = q->par, p = q->child[!d], b = p->child[d];
(r ? r->child[q->isr] : root) = p;
q->child[!d] = b;
p->child[d] = q;
if (b) {
b->par = q;
b->isr = !d;
}
p->par = r;
p->isr = q->isr;
q->par = p;
q->isr = d;
q->size[!d] = p->size[d];
q->height[!d] = p->height[d];
p->size[d] = q->size[0] + q->size[1] + 1;
p->height[d] = std::max(q->height[0], q->height[1]) + 1;
return p;
}
void update(node_ptr q) noexcept {
bool done = false;
while (true) {
if (!done && std::abs(q->height[0] - q->height[1]) > 1) {
const bool d = q->height[0] > q->height[1];
const node_ptr p = q->child[!d];
if (p->height[!d] < p->height[d]) rotate(p, !d);
q = rotate(q, d);
done = true;
}
const node_ptr r = q->par;
if (!r) break;
r->size[q->isr] = q->size[0] + q->size[1] + 1;
r->height[q->isr] = std::max(q->height[0], q->height[1]) + 1;
q = r;
}
}
node_ptr move(node_ptr q, bool d) const noexcept {
if (q == end()) return e_ptr[!d];
if (q == begin() && !d) return end();
if (q->child[d]) for (q = q->child[d]; q->child[!d]; q = q->child[!d]);
else {
while (q && (d ^ !q->isr)) q = q->par;
if (q) q = q->par;
}
return q;
}
};
#endif // INCLUDE_GUARD_AVL_TREE_HPP
#line 1 "DataStructure/AVL_Tree.hpp"
#include <algorithm>
#include <cstdint>
#include <vector>
#include <utility>
#include <stack>
/**
* @brief https://tkmst201.github.io/Library/DataStructure/AVL_Tree.hpp
*/
template<typename T>
struct AVL_Tree {
using size_type = std::size_t;
using value_type = T;
using const_reference = const value_type &;
private:
using uint32 = std::uint32_t;
using int8 = std::int8_t;
public:
struct Node;
using node_ptr = Node *;
using const_ptr = const Node * const;
struct Node {
value_type val;
node_ptr par, child[2] {nullptr, nullptr};
bool isr;
int8 height[2] {};
uint32 size[2] {};
Node(const_reference val, node_ptr par, bool isr) : val(val), par(par), isr(isr) {}
};
private:
size_type n = 0;
node_ptr root = nullptr;
node_ptr e_ptr[2] {nullptr, nullptr};
public:
AVL_Tree() = default;
AVL_Tree(const AVL_Tree & rhs) {
*this = rhs;
}
AVL_Tree(AVL_Tree && rhs) {
*this = std::forward<AVL_Tree>(rhs);
}
~AVL_Tree() {
clear();
}
AVL_Tree & operator =(const AVL_Tree & rhs) {
if (this != &rhs) {
clear();
auto dfs = [](auto self, const_ptr q, node_ptr r) -> node_ptr {
if (!q) return nullptr;
node_ptr res = new Node(q->val, r, q->isr);
for (int i = 0; i < 2; ++i) {
res->height[i] = q->height[i];
res->size[i] = q->size[i];
res->child[i] = self(self, q->child[i], res);
}
return res;
};
root = dfs(dfs, rhs.root, nullptr);
n = rhs.n;
e_ptr[0] = e_ptr[1] = root;
if (root) for (int i = 0; i < 2; ++i) while (e_ptr[i]->child[i]) e_ptr[i] = e_ptr[i]->child[i];
}
return *this;
}
AVL_Tree & operator =(AVL_Tree && rhs) {
if (this != &rhs) {
clear();
n = rhs.n;
rhs.n = 0;
root = rhs.root;
rhs.root = nullptr;
std::copy(rhs.e_ptr, rhs.e_ptr + 2, e_ptr);
std::fill(rhs.e_ptr, rhs.e_ptr + 2, nullptr);
}
return *this;
}
bool empty() const noexcept {
return size() == 0;
}
size_type size() const noexcept {
return n;
}
void clear() {
if (!root) return;
std::stack<node_ptr> stk;
stk.emplace(root);
while (!stk.empty()) {
node_ptr node = stk.top();
stk.pop();
if (node->child[0]) stk.emplace(node->child[0]);
if (node->child[1]) stk.emplace(node->child[1]);
delete node;
}
n = 0;
root = nullptr;
std::fill(e_ptr, e_ptr + 2, nullptr);
}
std::vector<value_type> enumerate() const {
std::vector<value_type> elements;
elements.reserve(size());
auto dfs = [&elements](auto self, const_ptr q) -> void {
if (!q) return;
self(self, q->child[0]);
elements.emplace_back(q->val);
self(self, q->child[1]);
};
dfs(dfs, root);
return elements;
}
node_ptr begin() const noexcept {
return e_ptr[0];
}
node_ptr end() const noexcept {
return nullptr;
}
node_ptr insert(const_reference x) {
node_ptr q = root, r = nullptr;
bool ef[2] {}, d = false;
while (q) {
r = q;
d = q->val <= x;
q = q->child[d];
ef[!d] = true;
}
q = new Node(x, r, d);
++n;
if (!ef[0]) e_ptr[0] = q;
if (!ef[1]) e_ptr[1] = q;
if (r) {
r->size[d] = 1;
r->height[d] = 1;
r->child[d] = q;
update(r);
}
else root = q;
return q;
}
node_ptr erase(const_reference x) noexcept {
node_ptr q = find(x);
if (q == end()) return end();
return erase(q);
}
node_ptr erase(node_ptr q) noexcept {
if (!q) return end();
const node_ptr ret = next(q);
if (q->child[0] && q->child[1]) {
node_ptr p = q->child[0];
while (p->child[1]) p = p->child[1];
q->val = std::move(p->val);
q = p;
}
if (e_ptr[0] == q) e_ptr[0] = next(q);
if (e_ptr[1] == q) e_ptr[1] = prev(q);
const node_ptr r = q->par;
if (q->child[0] || q->child[1]) {
const node_ptr p = q->child[0] ? q->child[0] : q->child[1];
if (r) {
r->size[q->isr] = q->size[p->isr];
r->height[q->isr] = q->height[p->isr];
r->child[q->isr] = p;
p->par = r;
p->isr = q->isr;
}
else {
p->par = nullptr;
root = p;
}
}
else if (r) {
r->size[q->isr] = 0;
r->height[q->isr] = 0;
r->child[q->isr] = nullptr;
}
else root = nullptr;
delete q;
--n;
if (r) update(r);
return ret;
}
node_ptr find(const_reference x) const noexcept {
const node_ptr q = lower_bound(x);
if (q != end() && q->val != x) return end();
return q;
}
node_ptr lower_bound(const_reference x) const noexcept {
node_ptr q = root;
if (!q) return end();
while (q->child[q->val < x]) q = q->child[q->val < x];
if (q->val < x) q = next(q);
return q;
}
node_ptr upper_bound(const_reference x) const noexcept {
node_ptr q = root;
if (!q) return end();
while (q->child[q->val <= x]) q = q->child[q->val <= x];
if (q->val <= x) q = next(q);
return q;
}
size_type count_less_than(const_reference x) const noexcept {
size_type res = 0;
node_ptr q = root;
while (q != nullptr) {
bool r = q->val < x;
if (r) res += q->size[0] + 1;
q = q->child[r];
}
return res;
}
size_type count_less_equal(const_reference x) const noexcept {
size_type res = 0;
node_ptr q = root;
while (q != nullptr) {
bool r = q->val <= x;
if (r) res += q->size[0] + 1;
q = q->child[r];
}
return res;
}
size_type count_greater_than(const_reference x) const noexcept {
return size() - count_less_equal(x);
}
size_type count_greater_equal(const_reference x) const noexcept {
return size() - count_less_than(x);
}
size_type count(const_reference x) const noexcept {
return count_less_equal(x) - count_less_than(x);
}
node_ptr k_th_smallest(uint32 k) const noexcept {
if (k == 0 || n < k) return end();
node_ptr q = root;
while (k != q->size[0] + 1) {
if (k > q->size[0] + 1) k -= q->size[0] + 1, q = q->child[1];
else q = q->child[0];
}
return q;
}
node_ptr k_th_largest(uint32 k) const noexcept {
if (k == 0 || n < k) return end();
return k_th_smallest(n - k + 1);
}
node_ptr next(node_ptr q) const noexcept {
return move(q, true);
}
node_ptr prev(node_ptr q) const noexcept {
return move(q, false);
}
private:
node_ptr rotate(node_ptr q, bool d) noexcept {
node_ptr r = q->par, p = q->child[!d], b = p->child[d];
(r ? r->child[q->isr] : root) = p;
q->child[!d] = b;
p->child[d] = q;
if (b) {
b->par = q;
b->isr = !d;
}
p->par = r;
p->isr = q->isr;
q->par = p;
q->isr = d;
q->size[!d] = p->size[d];
q->height[!d] = p->height[d];
p->size[d] = q->size[0] + q->size[1] + 1;
p->height[d] = std::max(q->height[0], q->height[1]) + 1;
return p;
}
void update(node_ptr q) noexcept {
bool done = false;
while (true) {
if (!done && std::abs(q->height[0] - q->height[1]) > 1) {
const bool d = q->height[0] > q->height[1];
const node_ptr p = q->child[!d];
if (p->height[!d] < p->height[d]) rotate(p, !d);
q = rotate(q, d);
done = true;
}
const node_ptr r = q->par;
if (!r) break;
r->size[q->isr] = q->size[0] + q->size[1] + 1;
r->height[q->isr] = std::max(q->height[0], q->height[1]) + 1;
q = r;
}
}
node_ptr move(node_ptr q, bool d) const noexcept {
if (q == end()) return e_ptr[!d];
if (q == begin() && !d) return end();
if (q->child[d]) for (q = q->child[d]; q->child[!d]; q = q->child[!d]);
else {
while (q && (d ^ !q->isr)) q = q->par;
if (q) q = q->par;
}
return q;
}
};