tkmst201's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub tkmst201/Library

:heavy_check_mark: 赤黒木 (平衡二分探索木)
(DataStructure/RedBlackTree.hpp)

概要

順序付き集合を扱う赤黒木です。 要素数 $N$ に対し、基本的な操作を $\Theta(\log{N})$ で行うことができます。
AVL 木 と比べると inserterase は速く、検索クエリは遅いです。
速度はそこまで変わらないので、より機能が多い AVL 木 を使用した方が良いです。


コンストラクタ

以降、要素数を $N$ とします。


制約


RedBlackTree()

空の赤黒木を作成します。

計算量


RedBlackTree(const RedBlackTree & rhs)

rhs をコピーします。

計算量


RedBlackTree(RedBlackTree && rhs)

rhs をムーブします。 ムーブ後の rhs は空の赤黒木となります。

計算量



メンバ関数

すべての要素は値の昇順に、同じ値同士は追加した順に順序付けされているとします。


bool empty()

赤黒木が空であるか判定します。 空なら $true$ 、そうでないなら $false$ を返します。

計算量


size_t size()

要素数 $N$ を返します。

計算量


void clear()

すべての要素を削除します。

計算量


std::vector enumerate()

すべての要素を昇順に列挙します。

計算量


Node * begin()

先頭の要素のポインタを返します。

計算量


Node * end()

最後の要素の次の位置を示すポインタを返します。

計算量


Node * insert(const T & x)

値 $x$ を追加します。 追加した要素のポインタを返します。

計算量


Node * erase(const T & x)

値 $x$ と等しい要素の中で先頭の要素を削除し、その次の要素のポインタを返します。 値 $x$ が存在しない場合は何も行わず、end() を返します。

計算量


Node * erase(Node * q)

ポインタ $q$ が指す要素を削除し、その次の要素のポインタを返します。

制約

計算量


Node * find(const T & x)

値 $x$ と等しい要素の中で先頭の要素のポインタを返します。 そのような要素が存在しない場合は end() を返します。

計算量


Node * lower_bound(const T & x)

値 $x$ 以上の要素の中で先頭の要素のポインタを返します。 そのような要素が存在しない場合は end() を返します。

計算量


Node * upper_bound(const T & x)

値 $x$ より大きい要素の中で先頭の要素のポインタを返します。 そのような要素が存在しない場合は end() を返します。

計算量


Node * next(Node * q)

ポインタ $q$ が指す要素の次の要素のポインタを返します。 $q =$ end() の場合は begin() を返します。

制約

計算量


Node * prev(Node * q)

ポインタ $q$ が指す要素の $1$ つ前の要素のポインタを返します。 $q =$ begin() の場合は end() を返します。

制約

計算量



使用例

#include <bits/stdc++.h>
#include "DataStructure/RedBlackTree.hpp"
using namespace std;

int main() {
	RedBlackTree<int> rbt;
	cout << "size = " << rbt.size() << endl; // 0
	
	for (int i : {1, 2, 2, 3, 4, 5, 6, 6, 6, 7}) rbt.insert(i);
	cout << "size = " << rbt.size() << endl; // 10
	
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // 1 2 2 3 4 5 6 6 6 7
	cout << "min = " << rbt.begin()->val << endl; // 1
	cout << "max = " << rbt.prev(rbt.end())->val << endl; // 7
	
	auto it = rbt.find(2);
	cout << "find(2) : " << it->val << endl; // 2
	cout << "next : " << rbt.next(it)->val << endl; // 2
	cout << "prev : " << rbt.prev(it)->val << endl; // 1
	
	rbt.erase(it);
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 6 7
	rbt.erase(6);
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
	
	cout << "lower_bound(-1) = " << rbt.lower_bound(-1)->val << endl; // 1
	cout << "upper_bound(6) = " << rbt.upper_bound(6)->val << endl; // 7
	cout << "upper_bound(7) == end() : " << boolalpha << (rbt.upper_bound(7) == rbt.end()) << endl; // true
	
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
	{
		RedBlackTree<int> rbt2 = rbt;
		for (int i : rbt2.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
	}
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
	{
		RedBlackTree<int> rbt2 = std::move(rbt);
		for (int i : rbt2.enumerate()) cout << i << " "; cout << '\n'; // 1 2 3 4 5 6 6 7
	}
	for (int i : rbt.enumerate()) cout << i << " "; cout << '\n'; // empty
	cout << "empty = " << boolalpha << rbt.empty() << endl; // true
}


参考

2020/08/27: https://ja.wikipedia.org/wiki/%E8%B5%A4%E9%BB%92%E6%9C%A8
2020/08/27: http://wwwa.pikara.ne.jp/okojisan/rb-tree/index.html
2020/08/27: http://www.nct9.ne.jp/m_hiroi/light/pyalgo16.html


Verified with

Code

#ifndef INCLUDE_GUARD_RED_BLACK_TREE_HPP
#define INCLUDE_GUARD_RED_BLACK_TREE_HPP

#include <algorithm>
#include <vector>
#include <utility>
#include <stack>

/**
 * @brief https://tkmst201.github.io/Library/DataStructure/RedBlackTree.hpp
 */
template<typename T>
struct RedBlackTree {
	using size_type = std::size_t;
	using value_type = T;
	using const_reference = const value_type &;
	
public:
	struct Node;
	using node_ptr = Node *;
	using const_ptr = const Node * const;
	struct Node {
		value_type val;
		bool isred;
		node_ptr par;
		bool isr;
		node_ptr child[2] {nullptr, nullptr};
		Node(const_reference val, bool isred, node_ptr par, bool isr)
			: val(val), isred(isred), par(par), isr(isr) {}
	};
	
private:
	size_type n = 0;
	node_ptr root = nullptr;
	node_ptr e_ptr[2] {nullptr, nullptr};
	
public:
	RedBlackTree() = default;
	
	RedBlackTree(const RedBlackTree & rhs) {
		*this = rhs;
	}
	
	RedBlackTree(RedBlackTree && rhs) {
		*this = std::forward<RedBlackTree>(rhs);
	}
	
	~RedBlackTree() {
		clear();
	}
	
	RedBlackTree & operator =(const RedBlackTree & rhs) {
		if (this != &rhs) {
			clear();
			auto dfs = [](auto self, const_ptr q, node_ptr r) -> node_ptr {
				if (!q) return nullptr;
				node_ptr res = new Node(q->val, q->isred, r, q->isr);
				for (int i = 0; i < 2; ++i) res->child[i] = self(self, q->child[i], res);
				return res;
			};
			root = dfs(dfs, rhs.root, nullptr);
			n = rhs.n;
			e_ptr[0] = e_ptr[1] = root;
			if (root) for (int i = 0; i < 2; ++i) while (e_ptr[i]->child[i]) e_ptr[i] = e_ptr[i]->child[i];
		}
		return *this;
	}
	
	RedBlackTree & operator =(RedBlackTree && rhs) {
		if (this != &rhs) {
			clear();
			n = rhs.n;
			rhs.n = 0;
			root = rhs.root;
			rhs.root = nullptr;
			std::copy(rhs.e_ptr, rhs.e_ptr + 2, e_ptr);
			std::fill(rhs.e_ptr, rhs.e_ptr + 2, nullptr);
		}
		return *this;
	}
	
	bool empty() const noexcept {
		return size() == 0;
	}
	
	size_type size() const noexcept {
		return n;
	}
	
	void clear() {
		if (!root) return;
		std::stack<node_ptr> stk;
		stk.emplace(root);
		while (!stk.empty()) {
			node_ptr node = stk.top();
			stk.pop();
			if (node->child[0]) stk.emplace(node->child[0]);
			if (node->child[1]) stk.emplace(node->child[1]);
			delete node;
		}
		n = 0;
		root = nullptr;
		std::fill(e_ptr, e_ptr + 2, nullptr);
	}
	
	std::vector<value_type> enumerate() const {
		std::vector<value_type> elements;
		elements.reserve(size());
		auto dfs = [&elements](auto self, const_ptr q) -> void {
			if (!q) return;
			self(self, q->child[0]);
			elements.emplace_back(q->val);
			self(self, q->child[1]);
		};
		dfs(dfs, root);
		return elements;
	}
	
	node_ptr begin() const noexcept {
		return e_ptr[0];
	}
	
	node_ptr end() const noexcept {
		return nullptr;
	}
	
	node_ptr insert(const_reference x) {
		node_ptr p = nullptr, node = root;
		bool ef[2] {}, d = false;
		while (node) {
			d = node->val <= x;
			p = node;
			node = node->child[d];
			ef[!d] = true;
		}
		node = new Node{x, true, p, d};
		++n;
		if (!ef[0]) e_ptr[0] = node;
		if (!ef[1]) e_ptr[1] = node;
		if (!p) {
			root = node;
			node->isred = false;
			return node;
		}
		p->child[d] = node;
		if (!p->isred) return node;
		node_ptr g = p->par, u = g->child[!p->isr];
		if (!u) {
			if (node->isr == p->isr) {
				rotate(g, !p->isr);
				p->isred = false;
				g->isred = true;
				return node;
			}
			else {
				p->child[node->isr] = nullptr;
				g->child[!p->isr] = node;
				node->par = g;
				std::swap(g->val, node->val);
				if (e_ptr[0] == g) e_ptr[0] = node;
				if (e_ptr[1] == g) e_ptr[1] = node;
				return g;
			}
		}
		while (u->isred) {
			g->isred = true;
			p->isred = false;
			u->isred = false;
			const node_ptr cur = g;
			p = cur->par;
			if (!p) { cur->isred = false; break; }
			if (!p->isred) break;
			g = p->par;
			u = g->child[!p->isr];
			if (!u->isred) {
				if (cur->isr == p->isr) {
					rotate(g, !p->isr);
					p->isred = false;
					g->isred = true;
				}
				else {
					rotate(p, p->isr);
					rotate(g, !p->isr);
					cur->isred = false;
					g->isred = true;
				}
			}
		}
		return node;
	}
	
	node_ptr erase(const_reference x) noexcept {
		node_ptr q = find(x);
		if (q == end()) return end();
		return erase(q);
	}
	
	node_ptr erase(node_ptr q) noexcept {
		if (!q) return end();
		node_ptr ret = next(q);
		if (q->child[0]) {
			node_ptr p = q->child[0];
			while (p->child[1]) p = p->child[1];
			q->val = std::move(p->val);
			q = p;
		}
		if (e_ptr[0] == q) e_ptr[0] = next(q);
		if (e_ptr[1] == q) e_ptr[1] = prev(q);
		node_ptr ch = q->child[0] ? q->child[0] : q->child[1];
		if (ch) {
			q->val = std::move(ch->val);
			if (e_ptr[0] == ch) e_ptr[0] = q;
			if (e_ptr[1] == ch) e_ptr[1] = q;
			if (ret == ch) ret = q;
			q = ch;
		}
		node_ptr p = q->par;
		bool isred = q->isred, isr = q->isr;
		(p ? p->child[q->isr] : root) = nullptr;
		delete q;
		--n;
		if (!root || isred) return ret;
		while (p) {
			node_ptr c = p->child[!isr];
			if (c->isred) {
				rotate(p, isr);
				c->isred = false;
				p->isred = true;
				c = p->child[!isr];
			}
			node_ptr g = c->child[!isr];
			if (g && g->isred) {
				rotate(p, isr);
				c->isred = p->isred;
				g->isred = false;
				p->isred = false;
				break;
			}
			g = c->child[isr];
			if (g && g->isred) {
				rotate(c, !isr);
				rotate(p, isr);
				g->isred = p->isred;
				p->isred = false;
				break;
			}
			c->isred = true;
			if (p->isred) {
				p->isred = false;
				break;
			}
			isr = p->isr;
			p = p->par;
		}
		if (root) root->isred = false;
		return ret;
	}
	
	node_ptr find(const_reference x) const noexcept {
		const node_ptr q = lower_bound(x);
		if (q != end() && q->val != x) return end();
		return q;
	}
	
	node_ptr lower_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val < x]) q = q->child[q->val < x];
		if (q->val < x) q = next(q);
		return q;
	}
	
	node_ptr upper_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val <= x]) q = q->child[q->val <= x];
		if (q->val <= x) q = next(q);
		return q;
	}
	
	node_ptr next(node_ptr q) const noexcept {
		return move(q, true);
	}
	
	node_ptr prev(node_ptr q) const noexcept {
		return move(q, false);
	}
	
private:
	void rotate(node_ptr q, bool d) noexcept {
		node_ptr r = q->par, p = q->child[!d], b = p->child[d];
		(r ? r->child[q->isr] : root) = p;
		q->child[!d] = b;
		p->child[d] = q;
		if (b) {
			b->par = q;
			b->isr = !d;
		}
		p->par = r;
		p->isr = q->isr;
		q->par = p;
		q->isr = d;
	}
	
	node_ptr move(node_ptr q, bool d) const noexcept {
		if (q == end()) return e_ptr[!d];
		if (q == begin() && !d) return end();
		if (q->child[d]) for (q = q->child[d]; q->child[!d]; q = q->child[!d]);
		else {
			while (q && (d ^ !q->isr)) q = q->par;
			if (q) q = q->par;
		}
		return q;
	}
};

#endif // INCLUDE_GUARD_RED_BLACK_TREE_HPP
#line 1 "DataStructure/RedBlackTree.hpp"



#include <algorithm>
#include <vector>
#include <utility>
#include <stack>

/**
 * @brief https://tkmst201.github.io/Library/DataStructure/RedBlackTree.hpp
 */
template<typename T>
struct RedBlackTree {
	using size_type = std::size_t;
	using value_type = T;
	using const_reference = const value_type &;
	
public:
	struct Node;
	using node_ptr = Node *;
	using const_ptr = const Node * const;
	struct Node {
		value_type val;
		bool isred;
		node_ptr par;
		bool isr;
		node_ptr child[2] {nullptr, nullptr};
		Node(const_reference val, bool isred, node_ptr par, bool isr)
			: val(val), isred(isred), par(par), isr(isr) {}
	};
	
private:
	size_type n = 0;
	node_ptr root = nullptr;
	node_ptr e_ptr[2] {nullptr, nullptr};
	
public:
	RedBlackTree() = default;
	
	RedBlackTree(const RedBlackTree & rhs) {
		*this = rhs;
	}
	
	RedBlackTree(RedBlackTree && rhs) {
		*this = std::forward<RedBlackTree>(rhs);
	}
	
	~RedBlackTree() {
		clear();
	}
	
	RedBlackTree & operator =(const RedBlackTree & rhs) {
		if (this != &rhs) {
			clear();
			auto dfs = [](auto self, const_ptr q, node_ptr r) -> node_ptr {
				if (!q) return nullptr;
				node_ptr res = new Node(q->val, q->isred, r, q->isr);
				for (int i = 0; i < 2; ++i) res->child[i] = self(self, q->child[i], res);
				return res;
			};
			root = dfs(dfs, rhs.root, nullptr);
			n = rhs.n;
			e_ptr[0] = e_ptr[1] = root;
			if (root) for (int i = 0; i < 2; ++i) while (e_ptr[i]->child[i]) e_ptr[i] = e_ptr[i]->child[i];
		}
		return *this;
	}
	
	RedBlackTree & operator =(RedBlackTree && rhs) {
		if (this != &rhs) {
			clear();
			n = rhs.n;
			rhs.n = 0;
			root = rhs.root;
			rhs.root = nullptr;
			std::copy(rhs.e_ptr, rhs.e_ptr + 2, e_ptr);
			std::fill(rhs.e_ptr, rhs.e_ptr + 2, nullptr);
		}
		return *this;
	}
	
	bool empty() const noexcept {
		return size() == 0;
	}
	
	size_type size() const noexcept {
		return n;
	}
	
	void clear() {
		if (!root) return;
		std::stack<node_ptr> stk;
		stk.emplace(root);
		while (!stk.empty()) {
			node_ptr node = stk.top();
			stk.pop();
			if (node->child[0]) stk.emplace(node->child[0]);
			if (node->child[1]) stk.emplace(node->child[1]);
			delete node;
		}
		n = 0;
		root = nullptr;
		std::fill(e_ptr, e_ptr + 2, nullptr);
	}
	
	std::vector<value_type> enumerate() const {
		std::vector<value_type> elements;
		elements.reserve(size());
		auto dfs = [&elements](auto self, const_ptr q) -> void {
			if (!q) return;
			self(self, q->child[0]);
			elements.emplace_back(q->val);
			self(self, q->child[1]);
		};
		dfs(dfs, root);
		return elements;
	}
	
	node_ptr begin() const noexcept {
		return e_ptr[0];
	}
	
	node_ptr end() const noexcept {
		return nullptr;
	}
	
	node_ptr insert(const_reference x) {
		node_ptr p = nullptr, node = root;
		bool ef[2] {}, d = false;
		while (node) {
			d = node->val <= x;
			p = node;
			node = node->child[d];
			ef[!d] = true;
		}
		node = new Node{x, true, p, d};
		++n;
		if (!ef[0]) e_ptr[0] = node;
		if (!ef[1]) e_ptr[1] = node;
		if (!p) {
			root = node;
			node->isred = false;
			return node;
		}
		p->child[d] = node;
		if (!p->isred) return node;
		node_ptr g = p->par, u = g->child[!p->isr];
		if (!u) {
			if (node->isr == p->isr) {
				rotate(g, !p->isr);
				p->isred = false;
				g->isred = true;
				return node;
			}
			else {
				p->child[node->isr] = nullptr;
				g->child[!p->isr] = node;
				node->par = g;
				std::swap(g->val, node->val);
				if (e_ptr[0] == g) e_ptr[0] = node;
				if (e_ptr[1] == g) e_ptr[1] = node;
				return g;
			}
		}
		while (u->isred) {
			g->isred = true;
			p->isred = false;
			u->isred = false;
			const node_ptr cur = g;
			p = cur->par;
			if (!p) { cur->isred = false; break; }
			if (!p->isred) break;
			g = p->par;
			u = g->child[!p->isr];
			if (!u->isred) {
				if (cur->isr == p->isr) {
					rotate(g, !p->isr);
					p->isred = false;
					g->isred = true;
				}
				else {
					rotate(p, p->isr);
					rotate(g, !p->isr);
					cur->isred = false;
					g->isred = true;
				}
			}
		}
		return node;
	}
	
	node_ptr erase(const_reference x) noexcept {
		node_ptr q = find(x);
		if (q == end()) return end();
		return erase(q);
	}
	
	node_ptr erase(node_ptr q) noexcept {
		if (!q) return end();
		node_ptr ret = next(q);
		if (q->child[0]) {
			node_ptr p = q->child[0];
			while (p->child[1]) p = p->child[1];
			q->val = std::move(p->val);
			q = p;
		}
		if (e_ptr[0] == q) e_ptr[0] = next(q);
		if (e_ptr[1] == q) e_ptr[1] = prev(q);
		node_ptr ch = q->child[0] ? q->child[0] : q->child[1];
		if (ch) {
			q->val = std::move(ch->val);
			if (e_ptr[0] == ch) e_ptr[0] = q;
			if (e_ptr[1] == ch) e_ptr[1] = q;
			if (ret == ch) ret = q;
			q = ch;
		}
		node_ptr p = q->par;
		bool isred = q->isred, isr = q->isr;
		(p ? p->child[q->isr] : root) = nullptr;
		delete q;
		--n;
		if (!root || isred) return ret;
		while (p) {
			node_ptr c = p->child[!isr];
			if (c->isred) {
				rotate(p, isr);
				c->isred = false;
				p->isred = true;
				c = p->child[!isr];
			}
			node_ptr g = c->child[!isr];
			if (g && g->isred) {
				rotate(p, isr);
				c->isred = p->isred;
				g->isred = false;
				p->isred = false;
				break;
			}
			g = c->child[isr];
			if (g && g->isred) {
				rotate(c, !isr);
				rotate(p, isr);
				g->isred = p->isred;
				p->isred = false;
				break;
			}
			c->isred = true;
			if (p->isred) {
				p->isred = false;
				break;
			}
			isr = p->isr;
			p = p->par;
		}
		if (root) root->isred = false;
		return ret;
	}
	
	node_ptr find(const_reference x) const noexcept {
		const node_ptr q = lower_bound(x);
		if (q != end() && q->val != x) return end();
		return q;
	}
	
	node_ptr lower_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val < x]) q = q->child[q->val < x];
		if (q->val < x) q = next(q);
		return q;
	}
	
	node_ptr upper_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val <= x]) q = q->child[q->val <= x];
		if (q->val <= x) q = next(q);
		return q;
	}
	
	node_ptr next(node_ptr q) const noexcept {
		return move(q, true);
	}
	
	node_ptr prev(node_ptr q) const noexcept {
		return move(q, false);
	}
	
private:
	void rotate(node_ptr q, bool d) noexcept {
		node_ptr r = q->par, p = q->child[!d], b = p->child[d];
		(r ? r->child[q->isr] : root) = p;
		q->child[!d] = b;
		p->child[d] = q;
		if (b) {
			b->par = q;
			b->isr = !d;
		}
		p->par = r;
		p->isr = q->isr;
		q->par = p;
		q->isr = d;
	}
	
	node_ptr move(node_ptr q, bool d) const noexcept {
		if (q == end()) return e_ptr[!d];
		if (q == begin() && !d) return end();
		if (q->child[d]) for (q = q->child[d]; q->child[!d]; q = q->child[!d]);
		else {
			while (q && (d ^ !q->isr)) q = q->par;
			if (q) q = q->par;
		}
		return q;
	}
};
Back to top page